Merge pull request #4 from spulec/master

pull latest changes from upstream spulec/moto repo
This commit is contained in:
Jon Beilke 2020-01-29 09:57:22 -06:00 committed by GitHub
commit d4851d3eab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
564 changed files with 99088 additions and 41467 deletions

2
.gitignore vendored
View File

@ -15,6 +15,8 @@ python_env
.ropeproject/ .ropeproject/
.pytest_cache/ .pytest_cache/
venv/ venv/
env/
.python-version .python-version
.vscode/ .vscode/
tests/file.tmp tests/file.tmp
.eggs/

View File

@ -1,37 +1,66 @@
dist: xenial dist: bionic
language: python language: python
sudo: false
services: services:
- docker - docker
python: python:
- 2.7 - 2.7
- 3.6 - 3.6
- 3.7 - 3.7
- 3.8
env: env:
- TEST_SERVER_MODE=false - TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true - TEST_SERVER_MODE=true
before_install: before_install:
- export BOTO_CONFIG=/dev/null - export BOTO_CONFIG=/dev/null
install: install:
# We build moto first so the docker container doesn't try to compile it as well, also note we don't use - |
# -d for docker run so the logs show up in travis python setup.py sdist
# Python images come from here: https://hub.docker.com/_/python/
- |
python setup.py sdist
if [ "$TEST_SERVER_MODE" = "true" ]; then if [ "$TEST_SERVER_MODE" = "true" ]; then
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh & if [ "$TRAVIS_PYTHON_VERSION" = "3.8" ]; then
# Python 3.8 does not provide Stretch images yet [1]
# [1] https://github.com/docker-library/python/issues/428
PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-buster
else
PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-stretch
fi fi
travis_retry pip install boto==2.45.0 docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh &
travis_retry pip install boto3 fi
travis_retry pip install dist/moto*.gz travis_retry pip install boto==2.45.0
travis_retry pip install coveralls==1.1 travis_retry pip install boto3
travis_retry pip install -r requirements-dev.txt travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt
if [ "$TEST_SERVER_MODE" = "true" ]; then if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py python wait_for.py
fi fi
before_script:
- if [[ $TRAVIS_PYTHON_VERSION == "3.7" ]]; then make lint; fi
script: script:
- make test - make test-only
after_success: after_success:
- coveralls - coveralls
before_deploy:
- git checkout $TRAVIS_BRANCH
- git fetch --unshallow
- python update_version_from_git.py
deploy:
- provider: pypi
distributions: sdist bdist_wheel
user: spulec
password:
secure: NxnPylnTfekJmGyoufCw0lMoYRskSMJzvAIyAlJJVYKwEhmiCPOrdy5qV8i8mRZ1AkUsqU3jBZ/PD56n96clHW0E3d080UleRDj6JpyALVdeLfMqZl9kLmZ8bqakWzYq3VSJKw2zGP/L4tPGf8wTK1SUv9yl/YNDsBdCkjDverw=
on:
branch:
- master
skip_cleanup: true
skip_existing: true
# - provider: pypi
# distributions: sdist bdist_wheel
# user: spulec
# password:
# secure: NxnPylnTfekJmGyoufCw0lMoYRskSMJzvAIyAlJJVYKwEhmiCPOrdy5qV8i8mRZ1AkUsqU3jBZ/PD56n96clHW0E3d080UleRDj6JpyALVdeLfMqZl9kLmZ8bqakWzYq3VSJKw2zGP/L4tPGf8wTK1SUv9yl/YNDsBdCkjDverw=
# on:
# tags: true
# skip_existing: true

View File

@ -54,4 +54,7 @@ Moto is written by Steve Pulec with contributions from:
* [William Richard](https://github.com/william-richard) * [William Richard](https://github.com/william-richard)
* [Alex Casalboni](https://github.com/alexcasalboni) * [Alex Casalboni](https://github.com/alexcasalboni)
* [Jon Beilke](https://github.com/jrbeilke) * [Jon Beilke](https://github.com/jrbeilke)
* [Bendeguz Acs](https://github.com/acsbendi)
* [Craig Anderson](https://github.com/craiga)
* [Robert Lewis](https://github.com/ralewis85) * [Robert Lewis](https://github.com/ralewis85)
* [Kyle Jones](https://github.com/Kerl1310)

View File

@ -1,6 +1,189 @@
Moto Changelog Moto Changelog
=================== ===================
1.3.14
-----
General Changes:
* Support for Python 3.8
* Linting: Black is now enforced.
New Services:
* Athena
* Config
* DataSync
* Step Functions
New methods:
* Athena:
* create_work_group()
* list_work_groups()
* API Gateway:
* delete_stage()
* update_api_key()
* CloudWatch Logs
* list_tags_log_group()
* tag_log_group()
* untag_log_group()
* Config
* batch_get_resource_config()
* delete_aggregation_authorization()
* delete_configuration_aggregator()
* describe_aggregation_authorizations()
* describe_configuration_aggregators()
* get_resource_config_history()
* list_aggregate_discovered_resources() (For S3)
* list_discovered_resources() (For S3)
* put_aggregation_authorization()
* put_configuration_aggregator()
* Cognito
* assume_role_with_web_identity()
* describe_identity_pool()
* get_open_id_token()
* update_user_pool_domain()
* DataSync:
* cancel_task_execution()
* create_location()
* create_task()
* start_task_execution()
* EC2:
* create_launch_template()
* create_launch_template_version()
* describe_launch_template_versions()
* describe_launch_templates()
* ECS
* decrypt()
* encrypt()
* generate_data_key_without_plaintext()
* generate_random()
* re_encrypt()
* Glue
* batch_get_partition()
* IAM
* create_open_id_connect_provider()
* create_virtual_mfa_device()
* delete_account_password_policy()
* delete_open_id_connect_provider()
* delete_policy()
* delete_virtual_mfa_device()
* get_account_password_policy()
* get_open_id_connect_provider()
* list_open_id_connect_providers()
* list_virtual_mfa_devices()
* update_account_password_policy()
* Lambda
* create_event_source_mapping()
* delete_event_source_mapping()
* get_event_source_mapping()
* list_event_source_mappings()
* update_configuration()
* update_event_source_mapping()
* update_function_code()
* KMS
* decrypt()
* encrypt()
* generate_data_key_without_plaintext()
* generate_random()
* re_encrypt()
* SES
* send_templated_email()
* SNS
* add_permission()
* list_tags_for_resource()
* remove_permission()
* tag_resource()
* untag_resource()
* SSM
* describe_parameters()
* get_parameter_history()
* Step Functions
* create_state_machine()
* delete_state_machine()
* describe_execution()
* describe_state_machine()
* describe_state_machine_for_execution()
* list_executions()
* list_state_machines()
* list_tags_for_resource()
* start_execution()
* stop_execution()
SQS
* list_queue_tags()
* send_message_batch()
General updates:
* API Gateway:
* Now generates valid IDs
* API Keys, Usage Plans now support tags
* ACM:
* list_certificates() accepts the status parameter
* Batch:
* submit_job() can now be called with job name
* CloudWatch Events
* Multi-region support
* CloudWatch Logs
* get_log_events() now supports pagination
* Cognito:
* Now throws UsernameExistsException for known users
* DynamoDB
* update_item() now supports lists, the list_append-operator and removing nested items
* delete_item() now supports condition expressions
* get_item() now supports projection expression
* Enforces 400KB item size
* Validation on duplicate keys in batch_get_item()
* Validation on AttributeDefinitions on create_table()
* Validation on Query Key Expression
* Projection Expressions now support nested attributes
* EC2:
* Change DesiredCapacity behaviour for AutoScaling groups
* Extend list of supported EC2 ENI properties
* Create ASG from Instance now supported
* ASG attached to a terminated instance now recreate the instance of required
* Unify OwnerIDs
* ECS
* Task definition revision deregistration: remaining revisions now remain unchanged
* Fix created_at/updated_at format for deployments
* Support multiple regions
* ELB
* Return correct response then describing target health of stopped instances
* Target groups now longer show terminated instances
* 'fixed-response' now a supported action-type
* Now supports redirect: authenticate-cognito
* Kinesis FireHose
* Now supports ExtendedS3DestinationConfiguration
* KMS
* Now supports tags
* Organizations
* create_organization() now creates Master account
* Redshift
* Fix timezone problems when creating a cluster
* Support for enhanced_vpc_routing-parameter
* Route53
* Implemented UPSERT for change_resource_records
* S3:
* Support partNumber for head_object
* Support for INTELLIGENT_TIERING, GLACIER and DEEP_ARCHIVE
* Fix KeyCount attribute
* list_objects now supports pagination (next_marker)
* Support tagging for versioned objects
* STS
* Implement validation on policy length
* Lambda
* Support EventSourceMappings for SQS, DynamoDB
* get_function(), delete_function() now both support ARNs as parameters
* IAM
* Roles now support tags
* Policy Validation: SID can be empty
* Validate roles have no attachments when deleting
* SecretsManager
* Now supports binary secrets
* IOT
* update_thing_shadow validation
* delete_thing now also removed principals
* SQS
* Tags supported for create_queue()
1.3.7 1.3.7
----- -----

120
CONFIG_README.md Normal file
View File

@ -0,0 +1,120 @@
# AWS Config Querying Support in Moto
An experimental feature for AWS Config has been developed to provide AWS Config capabilities in your unit tests.
This feature is experimental as there are many services that are not yet supported and will require the community to add them in
over time. This page details how the feature works and how you can use it.
## What is this and why would I use this?
AWS Config is an AWS service that describes your AWS resource types and can track their changes over time. At this time, moto does not
have support for handling the configuration history changes, but it does have a few methods mocked out that can be immensely useful
for unit testing.
If you are developing automation that needs to pull against AWS Config, then this will help you write tests that can simulate your
code in production.
## How does this work?
The AWS Config capabilities in moto work by examining the state of resources that are created within moto, and then returning that data
in the way that AWS Config would return it (sans history). This will work by querying all of the moto backends (regions) for a given
resource type.
However, this will only work on resource types that have this enabled.
### Current enabled resource types:
1. S3
## Developer Guide
There are several pieces to this for adding new capabilities to moto:
1. Listing resources
1. Describing resources
For both, there are a number of pre-requisites:
### Base Components
In the `moto/core/models.py` file is a class named `ConfigQueryModel`. This is a base class that keeps track of all the
resource type backends.
At a minimum, resource types that have this enabled will have:
1. A `config.py` file that will import the resource type backends (from the `__init__.py`)
1. In the resource's `config.py`, an implementation of the `ConfigQueryModel` class with logic unique to the resource type
1. An instantiation of the `ConfigQueryModel`
1. In the `moto/config/models.py` file, import the `ConfigQueryModel` instantiation, and update `RESOURCE_MAP` to have a mapping of the AWS Config resource type
to the instantiation on the previous step (just imported).
An example of the above is implemented for S3. You can see that by looking at:
1. `moto/s3/config.py`
1. `moto/config/models.py`
As well as the corresponding unit tests in:
1. `tests/s3/test_s3.py`
1. `tests/config/test_config.py`
Note for unit testing, you will want to add a test to ensure that you can query all the resources effectively. For testing this feature,
the unit tests for the `ConfigQueryModel` will not make use of `boto` to create resources, such as S3 buckets. You will need to use the
backend model methods to provision the resources. This is to make tests compatible with the moto server. You should absolutely make tests
in the resource type to test listing and object fetching.
### Listing
S3 is currently the model implementation, but it also odd in that S3 is a global resource type with regional resource residency.
But for most resource types the following is true:
1. There are regional backends with their own sets of data
1. Config aggregation can pull data from any backend region -- we assume that everything lives in the same account
Implementing the listing capability will be different for each resource type. At a minimum, you will need to return a `List` of `Dict`s
that look like this:
```python
[
{
'type': 'AWS::The AWS Config data type',
'name': 'The name of the resource',
'id': 'The ID of the resource',
'region': 'The region of the resource -- if global, then you may want to have the calling logic pass in the
aggregator region in for the resource region -- or just us-east-1 :P'
}
, ...
]
```
It's recommended to read the comment for the `ConfigQueryModel`'s `list_config_service_resources` function in [base class here](moto/core/models.py).
^^ The AWS Config code will see this and format it correct for both aggregated and non-aggregated calls.
#### General implementation tips
The aggregation and non-aggregation querying can and should just use the same overall logic. The differences are:
1. Non-aggregated listing will specify the region-name of the resource backend `backend_region`
1. Aggregated listing will need to be able to list resource types across ALL backends and filter optionally by passing in `resource_region`.
An example of a working implementation of this is [S3](moto/s3/config.py).
Pagination should generally be able to pull out the resource across any region so should be sharded by `region-item-name` -- not done for S3
because S3 has a globally unique name space.
### Describing Resources
Fetching a resource's configuration has some similarities to listing resources, but it requires more work (to implement). Due to the
various ways that a resource can be configured, some work will need to be done to ensure that the Config dict returned is correct.
For most resource types the following is true:
1. There are regional backends with their own sets of data
1. Config aggregation can pull data from any backend region -- we assume that everything lives in the same account
The current implementation is for S3. S3 is very complex and depending on how the bucket is configured will depend on what Config will
return for it.
When implementing resource config fetching, you will need to return at a minimum `None` if the resource is not found, or a `dict` that looks
like what AWS Config would return.
It's recommended to read the comment for the `ConfigQueryModel` 's `get_config_resource` function in [base class here](moto/core/models.py).

View File

@ -2,6 +2,10 @@
Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project. Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project.
## Running the tests locally
Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
## Is there a missing feature? ## Is there a missing feature?
Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services. Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services.

File diff suppressed because it is too large Load Diff

View File

@ -10,16 +10,20 @@ endif
init: init:
@python setup.py develop @python setup.py develop
@pip install -r requirements.txt @pip install -r requirements-dev.txt
lint: lint:
flake8 moto flake8 moto
black --check moto/ tests/
test: lint test-only:
rm -f .coverage rm -f .coverage
rm -rf cover rm -rf cover
@nosetests -sv --with-coverage --cover-html ./tests/ $(TEST_EXCLUDE) @nosetests -sv --with-coverage --cover-html ./tests/ $(TEST_EXCLUDE)
test: lint test-only
test_server: test_server:
@TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/ @TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/
@ -27,7 +31,8 @@ aws_managed_policies:
scripts/update_managed_policies.py scripts/update_managed_policies.py
upload_pypi_artifact: upload_pypi_artifact:
python setup.py sdist bdist_wheel upload python setup.py sdist bdist_wheel
twine upload dist/*
push_dockerhub_image: push_dockerhub_image:
docker build -t motoserver/moto . docker build -t motoserver/moto .

328
README.md
View File

@ -5,8 +5,11 @@
[![Build Status](https://travis-ci.org/spulec/moto.svg?branch=master)](https://travis-ci.org/spulec/moto) [![Build Status](https://travis-ci.org/spulec/moto.svg?branch=master)](https://travis-ci.org/spulec/moto)
[![Coverage Status](https://coveralls.io/repos/spulec/moto/badge.svg?branch=master)](https://coveralls.io/r/spulec/moto) [![Coverage Status](https://coveralls.io/repos/spulec/moto/badge.svg?branch=master)](https://coveralls.io/r/spulec/moto)
[![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org) [![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org)
![PyPI](https://img.shields.io/pypi/v/moto.svg)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
# In a nutshell ## In a nutshell
Moto is a library that allows your tests to easily mock out AWS Services. Moto is a library that allows your tests to easily mock out AWS Services.
@ -47,7 +50,7 @@ def test_my_model_save():
body = conn.Object('mybucket', 'steve').get()['Body'].read().decode("utf-8") body = conn.Object('mybucket', 'steve').get()['Body'].read().decode("utf-8")
assert body == b'is awesome' assert body == 'is awesome'
``` ```
With the decorator wrapping the test, all the calls to s3 are automatically mocked out. The mock keeps the state of the buckets and keys. With the decorator wrapping the test, all the calls to s3 are automatically mocked out. The mock keeps the state of the buckets and keys.
@ -55,95 +58,96 @@ With the decorator wrapping the test, all the calls to s3 are automatically mock
It gets even better! Moto isn't just for Python code and it isn't just for S3. Look at the [standalone server mode](https://github.com/spulec/moto#stand-alone-server-mode) for more information about running Moto with other languages. Here's the status of the other AWS services implemented: It gets even better! Moto isn't just for Python code and it isn't just for S3. Look at the [standalone server mode](https://github.com/spulec/moto#stand-alone-server-mode) for more information about running Moto with other languages. Here's the status of the other AWS services implemented:
```gherkin ```gherkin
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Service Name | Decorator | Development Status | | Service Name | Decorator | Development Status |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| ACM | @mock_acm | all endpoints done | | ACM | @mock_acm | all endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| API Gateway | @mock_apigateway | core endpoints done | | API Gateway | @mock_apigateway | core endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Autoscaling | @mock_autoscaling| core endpoints done | | Autoscaling | @mock_autoscaling | core endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Cloudformation | @mock_cloudformation| core endpoints done | | Cloudformation | @mock_cloudformation | core endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Cloudwatch | @mock_cloudwatch | basic endpoints done | | Cloudwatch | @mock_cloudwatch | basic endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| CloudwatchEvents | @mock_events | all endpoints done | | CloudwatchEvents | @mock_events | all endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Cognito Identity | @mock_cognitoidentity| basic endpoints done | | Cognito Identity | @mock_cognitoidentity | basic endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Cognito Identity Provider | @mock_cognitoidp| basic endpoints done | | Cognito Identity Provider | @mock_cognitoidp | basic endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Config | @mock_config | basic endpoints done | | Config | @mock_config | basic endpoints done |
|------------------------------------------------------------------------------| | | | core endpoints done |
| Data Pipeline | @mock_datapipeline| basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Data Pipeline | @mock_datapipeline | basic endpoints done |
| DynamoDB | @mock_dynamodb | core endpoints done | |-------------------------------------------------------------------------------------|
| DynamoDB2 | @mock_dynamodb2 | all endpoints + partial indexes | | DynamoDB | @mock_dynamodb | core endpoints done |
|------------------------------------------------------------------------------| | DynamoDB2 | @mock_dynamodb2 | all endpoints + partial indexes |
| EC2 | @mock_ec2 | core endpoints done | |-------------------------------------------------------------------------------------|
| - AMI | | core endpoints done | | EC2 | @mock_ec2 | core endpoints done |
| - EBS | | core endpoints done | | - AMI | | core endpoints done |
| - Instances | | all endpoints done | | - EBS | | core endpoints done |
| - Security Groups | | core endpoints done | | - Instances | | all endpoints done |
| - Tags | | all endpoints done | | - Security Groups | | core endpoints done |
|------------------------------------------------------------------------------| | - Tags | | all endpoints done |
| ECR | @mock_ecr | basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | ECR | @mock_ecr | basic endpoints done |
| ECS | @mock_ecs | basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | ECS | @mock_ecs | basic endpoints done |
| ELB | @mock_elb | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | ELB | @mock_elb | core endpoints done |
| ELBv2 | @mock_elbv2 | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | ELBv2 | @mock_elbv2 | all endpoints done |
| EMR | @mock_emr | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | EMR | @mock_emr | core endpoints done |
| Glacier | @mock_glacier | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Glacier | @mock_glacier | core endpoints done |
| IAM | @mock_iam | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | IAM | @mock_iam | core endpoints done |
| IoT | @mock_iot | core endpoints done | |-------------------------------------------------------------------------------------|
| | @mock_iotdata | core endpoints done | | IoT | @mock_iot | core endpoints done |
|------------------------------------------------------------------------------| | | @mock_iotdata | core endpoints done |
| Lambda | @mock_lambda | basic endpoints done, requires | |-------------------------------------------------------------------------------------|
| | | docker | | Kinesis | @mock_kinesis | core endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Logs | @mock_logs | basic endpoints done | | KMS | @mock_kms | basic endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Kinesis | @mock_kinesis | core endpoints done | | Lambda | @mock_lambda | basic endpoints done, requires |
|------------------------------------------------------------------------------| | | | docker |
| KMS | @mock_kms | basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Logs | @mock_logs | basic endpoints done |
| Organizations | @mock_organizations | some core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Organizations | @mock_organizations | some core endpoints done |
| Polly | @mock_polly | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Polly | @mock_polly | all endpoints done |
| RDS | @mock_rds | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | RDS | @mock_rds | core endpoints done |
| RDS2 | @mock_rds2 | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | RDS2 | @mock_rds2 | core endpoints done |
| Redshift | @mock_redshift | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Redshift | @mock_redshift | core endpoints done |
| Route53 | @mock_route53 | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Route53 | @mock_route53 | core endpoints done |
| S3 | @mock_s3 | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | S3 | @mock_s3 | core endpoints done |
| SecretsManager | @mock_secretsmanager | basic endpoints done |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SecretsManager | @mock_secretsmanager | basic endpoints done |
| SES | @mock_ses | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SES | @mock_ses | all endpoints done |
| SNS | @mock_sns | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SNS | @mock_sns | all endpoints done |
| SQS | @mock_sqs | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SQS | @mock_sqs | core endpoints done |
| SSM | @mock_ssm | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SSM | @mock_ssm | core endpoints done |
| STS | @mock_sts | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | STS | @mock_sts | core endpoints done |
| SWF | @mock_swf | basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SWF | @mock_swf | basic endpoints done |
| X-Ray | @mock_xray | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | X-Ray | @mock_xray | all endpoints done |
|-------------------------------------------------------------------------------------|
``` ```
For a full list of endpoint [implementation coverage](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) For a full list of endpoint [implementation coverage](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md)
@ -252,6 +256,140 @@ def test_my_model_save():
mock.stop() mock.stop()
``` ```
## IAM-like Access Control
Moto also has the ability to authenticate and authorize actions, just like it's done by IAM in AWS. This functionality can be enabled by either setting the `INITIAL_NO_AUTH_ACTION_COUNT` environment variable or using the `set_initial_no_auth_action_count` decorator. Note that the current implementation is very basic, see [this file](https://github.com/spulec/moto/blob/master/moto/core/access_control.py) for more information.
### `INITIAL_NO_AUTH_ACTION_COUNT`
If this environment variable is set, moto will skip performing any authentication as many times as the variable's value, and only starts authenticating requests afterwards. If it is not set, it defaults to infinity, thus moto will never perform any authentication at all.
### `set_initial_no_auth_action_count`
This is a decorator that works similarly to the environment variable, but the settings are only valid in the function's scope. When the function returns, everything is restored.
```python
@set_initial_no_auth_action_count(4)
@mock_ec2
def test_describe_instances_allowed():
policy_document = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:Describe*",
"Resource": "*"
}
]
}
access_key = ...
# create access key for an IAM user/assumed role that has the policy above.
# this part should call __exactly__ 4 AWS actions, so that authentication and authorization starts exactly after this
client = boto3.client('ec2', region_name='us-east-1',
aws_access_key_id=access_key['AccessKeyId'],
aws_secret_access_key=access_key['SecretAccessKey'])
# if the IAM principal whose access key is used, does not have the permission to describe instances, this will fail
instances = client.describe_instances()['Reservations'][0]['Instances']
assert len(instances) == 0
```
See [the related test suite](https://github.com/spulec/moto/blob/master/tests/test_core/test_auth.py) for more examples.
## Experimental: AWS Config Querying
For details about the experimental AWS Config support please see the [AWS Config readme here](CONFIG_README.md).
## Very Important -- Recommended Usage
There are some important caveats to be aware of when using moto:
*Failure to follow these guidelines could result in your tests mutating your __REAL__ infrastructure!*
### How do I avoid tests from mutating my real infrastructure?
You need to ensure that the mocks are actually in place. Changes made to recent versions of `botocore`
have altered some of the mock behavior. In short, you need to ensure that you _always_ do the following:
1. Ensure that your tests have dummy environment variables set up:
export AWS_ACCESS_KEY_ID='testing'
export AWS_SECRET_ACCESS_KEY='testing'
export AWS_SECURITY_TOKEN='testing'
export AWS_SESSION_TOKEN='testing'
1. __VERY IMPORTANT__: ensure that you have your mocks set up __BEFORE__ your `boto3` client is established.
This can typically happen if you import a module that has a `boto3` client instantiated outside of a function.
See the pesky imports section below on how to work around this.
### Example on usage?
If you are a user of [pytest](https://pytest.org/en/latest/), you can leverage [pytest fixtures](https://pytest.org/en/latest/fixture.html#fixture)
to help set up your mocks and other AWS resources that you would need.
Here is an example:
```python
@pytest.fixture(scope='function')
def aws_credentials():
"""Mocked AWS Credentials for moto."""
os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
os.environ['AWS_SECURITY_TOKEN'] = 'testing'
os.environ['AWS_SESSION_TOKEN'] = 'testing'
@pytest.fixture(scope='function')
def s3(aws_credentials):
with mock_s3():
yield boto3.client('s3', region_name='us-east-1')
@pytest.fixture(scope='function')
def sts(aws_credentials):
with mock_sts():
yield boto3.client('sts', region_name='us-east-1')
@pytest.fixture(scope='function')
def cloudwatch(aws_credentials):
with mock_cloudwatch():
yield boto3.client('cloudwatch', region_name='us-east-1')
... etc.
```
In the code sample above, all of the AWS/mocked fixtures take in a parameter of `aws_credentials`,
which sets the proper fake environment variables. The fake environment variables are used so that `botocore` doesn't try to locate real
credentials on your system.
Next, once you need to do anything with the mocked AWS environment, do something like:
```python
def test_create_bucket(s3):
# s3 is a fixture defined above that yields a boto3 s3 client.
# Feel free to instantiate another boto3 S3 client -- Keep note of the region though.
s3.create_bucket(Bucket="somebucket")
result = s3.list_buckets()
assert len(result['Buckets']) == 1
assert result['Buckets'][0]['Name'] == 'somebucket'
```
### What about those pesky imports?
Recall earlier, it was mentioned that mocks should be established __BEFORE__ the clients are set up. One way
to avoid import issues is to make use of local Python imports -- i.e. import the module inside of the unit
test you want to run vs. importing at the top of the file.
Example:
```python
def test_something(s3):
from some.package.that.does.something.with.s3 import some_func # <-- Local import for unit test
# ^^ Importing here ensures that the mock has been established.
some_func() # The mock has been established from the "s3" pytest fixture, so this function that uses
# a package-level S3 client will properly use the mock and not reach out to AWS.
```
### Other caveats
For Tox, Travis CI, and other build systems, you might need to also perform a `touch ~/.aws/credentials`
command before running the tests. As long as that file is present (empty preferably) and the environment
variables above are set, you should be good to go.
## Stand-alone Server Mode ## Stand-alone Server Mode
Moto also has a stand-alone server mode. This allows you to utilize Moto also has a stand-alone server mode. This allows you to utilize
@ -318,3 +456,11 @@ boto3.resource(
```console ```console
$ pip install moto $ pip install moto
``` ```
## Releases
Releases are done from travisci. Fairly closely following this:
https://docs.travis-ci.com/user/deployment/pypi/
- Commits to `master` branch do a dev deploy to pypi.
- Commits to a tag do a real deploy to pypi.

View File

@ -30,6 +30,8 @@ Currently implemented Services:
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| Data Pipeline | @mock_datapipeline | basic endpoints done | | Data Pipeline | @mock_datapipeline | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| DataSync | @mock_datasync | some endpoints done |
+-----------------------+---------------------+-----------------------------------+
| - DynamoDB | - @mock_dynamodb | - core endpoints done | | - DynamoDB | - @mock_dynamodb | - core endpoints done |
| - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes| | - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes|
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+

View File

@ -17,66 +17,97 @@ with ``moto`` and its usage.
Currently implemented Services: Currently implemented Services:
------------------------------- -------------------------------
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Service Name | Decorator | Development Status | | Service Name | Decorator | Development Status |
+=======================+=====================+===================================+ +===========================+=======================+====================================+
| API Gateway | @mock_apigateway | core endpoints done | | ACM | @mock_acm | all endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Autoscaling | @mock_autoscaling | core endpoints done | | API Gateway | @mock_apigateway | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Cloudformation | @mock_cloudformation| core endpoints done | | Autoscaling | @mock_autoscaling | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Cloudwatch | @mock_cloudwatch | basic endpoints done | | Cloudformation | @mock_cloudformation | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Data Pipeline | @mock_datapipeline | basic endpoints done | | Cloudwatch | @mock_cloudwatch | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| - DynamoDB | - @mock_dynamodb | - core endpoints done | | CloudwatchEvents | @mock_events | all endpoints done |
| - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes| +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | Cognito Identity | @mock_cognitoidentity | all endpoints done |
| EC2 | @mock_ec2 | core endpoints done | +---------------------------+-----------------------+------------------------------------+
| - AMI | | - core endpoints done | | Cognito Identity Provider | @mock_cognitoidp | all endpoints done |
| - EBS | | - core endpoints done | +---------------------------+-----------------------+------------------------------------+
| - Instances | | - all endpoints done | | Config | @mock_config | basic endpoints done |
| - Security Groups | | - core endpoints done | +---------------------------+-----------------------+------------------------------------+
| - Tags | | - all endpoints done | | Data Pipeline | @mock_datapipeline | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| ECS | @mock_ecs | basic endpoints done | | DynamoDB | - @mock_dynamodb | - core endpoints done |
+-----------------------+---------------------+-----------------------------------+ | DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes |
| ELB | @mock_elb | core endpoints done | +---------------------------+-----------------------+------------------------------------+
| | @mock_elbv2 | core endpoints done | | EC2 | @mock_ec2 | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ | - AMI | | - core endpoints done |
| EMR | @mock_emr | core endpoints done | | - EBS | | - core endpoints done |
+-----------------------+---------------------+-----------------------------------+ | - Instances | | - all endpoints done |
| Glacier | @mock_glacier | core endpoints done | | - Security Groups | | - core endpoints done |
+-----------------------+---------------------+-----------------------------------+ | - Tags | | - all endpoints done |
| IAM | @mock_iam | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | ECR | @mock_ecr | basic endpoints done |
| Lambda | @mock_lambda | basic endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | ECS | @mock_ecs | basic endpoints done |
| Kinesis | @mock_kinesis | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | ELB | @mock_elb | core endpoints done |
| KMS | @mock_kms | basic endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | ELBv2 | @mock_elbv2 | all endpoints done |
| RDS | @mock_rds | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | EMR | @mock_emr | core endpoints done |
| RDS2 | @mock_rds2 | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | Glacier | @mock_glacier | core endpoints done |
| Redshift | @mock_redshift | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | IAM | @mock_iam | core endpoints done |
| Route53 | @mock_route53 | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | IoT | @mock_iot | core endpoints done |
| S3 | @mock_s3 | core endpoints done | | | @mock_iotdata | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SES | @mock_ses | core endpoints done | | Kinesis | @mock_kinesis | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SNS | @mock_sns | core endpoints done | | KMS | @mock_kms | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SQS | @mock_sqs | core endpoints done | | Lambda | @mock_lambda | basic endpoints done, |
+-----------------------+---------------------+-----------------------------------+ | | | requires docker |
| STS | @mock_sts | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | Logs | @mock_logs | basic endpoints done |
| SWF | @mock_swf | basic endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | Organizations | @mock_organizations | some core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Polly | @mock_polly | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| RDS | @mock_rds | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| RDS2 | @mock_rds2 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Redshift | @mock_redshift | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Route53 | @mock_route53 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| S3 | @mock_s3 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SecretsManager | @mock_secretsmanager | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SES | @mock_ses | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SFN | @mock_stepfunctions | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SNS | @mock_sns | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SQS | @mock_sqs | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SSM | @mock_ssm | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| STS | @mock_sts | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SWF | @mock_swf | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| X-Ray | @mock_xray | all endpoints done |
+---------------------------+-----------------------+------------------------------------+

View File

@ -1,61 +1,77 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import logging
from .acm import mock_acm # noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # noqa
from .athena import mock_athena # noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # noqa
from .batch import mock_batch # noqa
from .cloudformation import mock_cloudformation # noqa
from .cloudformation import mock_cloudformation_deprecated # noqa
from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # noqa
from .codecommit import mock_codecommit # noqa
from .codepipeline import mock_codepipeline # noqa
from .cognitoidentity import mock_cognitoidentity # noqa
from .cognitoidentity import mock_cognitoidentity_deprecated # noqa
from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # noqa
from .config import mock_config # noqa
from .datapipeline import mock_datapipeline # noqa
from .datapipeline import mock_datapipeline_deprecated # noqa
from .datasync import mock_datasync # noqa
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # noqa
from .dynamodbstreams import mock_dynamodbstreams # noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # noqa
from .ec2_instance_connect import mock_ec2_instance_connect # noqa
from .ecr import mock_ecr, mock_ecr_deprecated # noqa
from .ecs import mock_ecs, mock_ecs_deprecated # noqa
from .elb import mock_elb, mock_elb_deprecated # noqa
from .elbv2 import mock_elbv2 # noqa
from .emr import mock_emr, mock_emr_deprecated # noqa
from .events import mock_events # noqa
from .glacier import mock_glacier, mock_glacier_deprecated # noqa
from .glue import mock_glue # noqa
from .iam import mock_iam, mock_iam_deprecated # noqa
from .iot import mock_iot # noqa
from .iotdata import mock_iotdata # noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # noqa
from .kms import mock_kms, mock_kms_deprecated # noqa
from .logs import mock_logs, mock_logs_deprecated # noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # noqa
from .organizations import mock_organizations # noqa
from .polly import mock_polly # noqa
from .rds import mock_rds, mock_rds_deprecated # noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # noqa
from .redshift import mock_redshift, mock_redshift_deprecated # noqa
from .resourcegroups import mock_resourcegroups # noqa
from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # noqa
from .route53 import mock_route53, mock_route53_deprecated # noqa
from .s3 import mock_s3, mock_s3_deprecated # noqa
from .secretsmanager import mock_secretsmanager # noqa
from .ses import mock_ses, mock_ses_deprecated # noqa
from .sns import mock_sns, mock_sns_deprecated # noqa
from .sqs import mock_sqs, mock_sqs_deprecated # noqa
from .ssm import mock_ssm # noqa
from .stepfunctions import mock_stepfunctions # noqa
from .sts import mock_sts, mock_sts_deprecated # noqa
from .swf import mock_swf, mock_swf_deprecated # noqa
from .xray import XRaySegment, mock_xray, mock_xray_client # noqa
# import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL) # logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto' __title__ = "moto"
__version__ = '1.3.8' __version__ = "1.3.15.dev"
from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # flake8: noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # flake8: noqa
from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # flake8: noqa
from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # flake8: noqa
from .cognitoidentity import mock_cognitoidentity, mock_cognitoidentity_deprecated # flake8: noqa
from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # flake8: noqa
from .config import mock_config # flake8: noqa
from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # flake8: noqa
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # flake8: noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # flake8: noqa
from .dynamodbstreams import mock_dynamodbstreams # flake8: noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # flake8: noqa
from .ecr import mock_ecr, mock_ecr_deprecated # flake8: noqa
from .ecs import mock_ecs, mock_ecs_deprecated # flake8: noqa
from .elb import mock_elb, mock_elb_deprecated # flake8: noqa
from .elbv2 import mock_elbv2 # flake8: noqa
from .emr import mock_emr, mock_emr_deprecated # flake8: noqa
from .events import mock_events # flake8: noqa
from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa
from .glue import mock_glue # flake8: noqa
from .iam import mock_iam, mock_iam_deprecated # flake8: noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa
from .kms import mock_kms, mock_kms_deprecated # flake8: noqa
from .organizations import mock_organizations # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .polly import mock_polly # flake8: noqa
from .rds import mock_rds, mock_rds_deprecated # flake8: noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa
from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa
from .s3 import mock_s3, mock_s3_deprecated # flake8: noqa
from .ses import mock_ses, mock_ses_deprecated # flake8: noqa
from .secretsmanager import mock_secretsmanager # flake8: noqa
from .sns import mock_sns, mock_sns_deprecated # flake8: noqa
from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa
from .sts import mock_sts, mock_sts_deprecated # flake8: noqa
from .ssm import mock_ssm # flake8: noqa
from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa
from .swf import mock_swf, mock_swf_deprecated # flake8: noqa
from .xray import mock_xray, mock_xray_client, XRaySegment # flake8: noqa
from .logs import mock_logs, mock_logs_deprecated # flake8: noqa
from .batch import mock_batch # flake8: noqa
from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # flake8: noqa
from .iot import mock_iot # flake8: noqa
from .iotdata import mock_iotdata # flake8: noqa
try: try:
# Need to monkey-patch botocore requests back to underlying urllib3 classes # Need to monkey-patch botocore requests back to underlying urllib3 classes
from botocore.awsrequest import HTTPSConnectionPool, HTTPConnectionPool, HTTPConnection, VerifiedHTTPSConnection from botocore.awsrequest import (
HTTPSConnectionPool,
HTTPConnectionPool,
HTTPConnection,
VerifiedHTTPSConnection,
)
except ImportError: except ImportError:
pass pass
else: else:

View File

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import acm_backends from .models import acm_backends
from ..core.models import base_decorator from ..core.models import base_decorator
acm_backend = acm_backends['us-east-1'] acm_backend = acm_backends["us-east-1"]
mock_acm = base_decorator(acm_backends) mock_acm = base_decorator(acm_backends)

View File

@ -13,8 +13,9 @@ import cryptography.hazmat.primitives.asymmetric.rsa
from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
DEFAULT_ACCOUNT_ID = 123456789012
GOOGLE_ROOT_CA = b"""-----BEGIN CERTIFICATE----- GOOGLE_ROOT_CA = b"""-----BEGIN CERTIFICATE-----
MIIEKDCCAxCgAwIBAgIQAQAhJYiw+lmnd+8Fe2Yn3zANBgkqhkiG9w0BAQsFADBC MIIEKDCCAxCgAwIBAgIQAQAhJYiw+lmnd+8Fe2Yn3zANBgkqhkiG9w0BAQsFADBC
MQswCQYDVQQGEwJVUzEWMBQGA1UEChMNR2VvVHJ1c3QgSW5jLjEbMBkGA1UEAxMS MQswCQYDVQQGEwJVUzEWMBQGA1UEChMNR2VvVHJ1c3QgSW5jLjEbMBkGA1UEAxMS
@ -57,20 +58,29 @@ class AWSError(Exception):
self.message = message self.message = message
def response(self): def response(self):
resp = {'__type': self.TYPE, 'message': self.message} resp = {"__type": self.TYPE, "message": self.message}
return json.dumps(resp), dict(status=self.STATUS) return json.dumps(resp), dict(status=self.STATUS)
class AWSValidationException(AWSError): class AWSValidationException(AWSError):
TYPE = 'ValidationException' TYPE = "ValidationException"
class AWSResourceNotFoundException(AWSError): class AWSResourceNotFoundException(AWSError):
TYPE = 'ResourceNotFoundException' TYPE = "ResourceNotFoundException"
class CertBundle(BaseModel): class CertBundle(BaseModel):
def __init__(self, certificate, private_key, chain=None, region='us-east-1', arn=None, cert_type='IMPORTED', cert_status='ISSUED'): def __init__(
self,
certificate,
private_key,
chain=None,
region="us-east-1",
arn=None,
cert_type="IMPORTED",
cert_status="ISSUED",
):
self.created_at = datetime.datetime.now() self.created_at = datetime.datetime.now()
self.cert = certificate self.cert = certificate
self._cert = None self._cert = None
@ -87,7 +97,7 @@ class CertBundle(BaseModel):
if self.chain is None: if self.chain is None:
self.chain = GOOGLE_ROOT_CA self.chain = GOOGLE_ROOT_CA
else: else:
self.chain += b'\n' + GOOGLE_ROOT_CA self.chain += b"\n" + GOOGLE_ROOT_CA
# Takes care of PEM checking # Takes care of PEM checking
self.validate_pk() self.validate_pk()
@ -105,7 +115,7 @@ class CertBundle(BaseModel):
self.arn = arn self.arn = arn
@classmethod @classmethod
def generate_cert(cls, domain_name, sans=None): def generate_cert(cls, domain_name, region, sans=None):
if sans is None: if sans is None:
sans = set() sans = set()
else: else:
@ -114,149 +124,209 @@ class CertBundle(BaseModel):
sans.add(domain_name) sans.add(domain_name)
sans = [cryptography.x509.DNSName(item) for item in sans] sans = [cryptography.x509.DNSName(item) for item in sans]
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
subject = cryptography.x509.Name([ public_exponent=65537, key_size=2048, backend=default_backend()
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"), )
cryptography.x509.NameAttribute(cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, u"CA"), subject = cryptography.x509.Name(
cryptography.x509.NameAttribute(cryptography.x509.NameOID.LOCALITY_NAME, u"San Francisco"), [
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"My Company"), cryptography.x509.NameAttribute(
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, domain_name), cryptography.x509.NameOID.COUNTRY_NAME, "US"
]) ),
issuer = cryptography.x509.Name([ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon cryptography.x509.NameAttribute(
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"), cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA"
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"Amazon"), ),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, u"Server CA 1B"), cryptography.x509.NameAttribute(
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, u"Amazon"), cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco"
]) ),
cert = cryptography.x509.CertificateBuilder().subject_name( cryptography.x509.NameAttribute(
subject cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company"
).issuer_name( ),
issuer cryptography.x509.NameAttribute(
).public_key( cryptography.x509.NameOID.COMMON_NAME, domain_name
key.public_key() ),
).serial_number( ]
cryptography.x509.random_serial_number() )
).not_valid_before( issuer = cryptography.x509.Name(
datetime.datetime.utcnow() [ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
).not_valid_after( cryptography.x509.NameAttribute(
datetime.datetime.utcnow() + datetime.timedelta(days=365) cryptography.x509.NameOID.COUNTRY_NAME, "US"
).add_extension( ),
cryptography.x509.SubjectAlternativeName(sans), cryptography.x509.NameAttribute(
critical=False, cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon"
).sign(key, hashes.SHA512(), default_backend()) ),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, "Amazon"
),
]
)
cert = (
cryptography.x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(cryptography.x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
.add_extension(
cryptography.x509.SubjectAlternativeName(sans), critical=False
)
.sign(key, hashes.SHA512(), default_backend())
)
cert_armored = cert.public_bytes(serialization.Encoding.PEM) cert_armored = cert.public_bytes(serialization.Encoding.PEM)
private_key = key.private_bytes( private_key = key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL, format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption() encryption_algorithm=serialization.NoEncryption(),
) )
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION') return cls(
cert_armored,
private_key,
cert_type="AMAZON_ISSUED",
cert_status="PENDING_VALIDATION",
region=region,
)
def validate_pk(self): def validate_pk(self):
try: try:
self._key = serialization.load_pem_private_key(self.key, password=None, backend=default_backend()) self._key = serialization.load_pem_private_key(
self.key, password=None, backend=default_backend()
)
if self._key.key_size > 2048: if self._key.key_size > 2048:
AWSValidationException('The private key length is not supported. Only 1024-bit and 2048-bit are allowed.') AWSValidationException(
"The private key length is not supported. Only 1024-bit and 2048-bit are allowed."
)
except Exception as err: except Exception as err:
if isinstance(err, AWSValidationException): if isinstance(err, AWSValidationException):
raise raise
raise AWSValidationException('The private key is not PEM-encoded or is not valid.') raise AWSValidationException(
"The private key is not PEM-encoded or is not valid."
)
def validate_certificate(self): def validate_certificate(self):
try: try:
self._cert = cryptography.x509.load_pem_x509_certificate(self.cert, default_backend()) self._cert = cryptography.x509.load_pem_x509_certificate(
self.cert, default_backend()
)
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
if self._cert.not_valid_after < now: if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate has expired, is not valid.') raise AWSValidationException(
"The certificate has expired, is not valid."
)
if self._cert.not_valid_before > now: if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate is not in effect yet, is not valid.') raise AWSValidationException(
"The certificate is not in effect yet, is not valid."
)
# Extracting some common fields for ease of use # Extracting some common fields for ease of use
# Have to search through cert.subject for OIDs # Have to search through cert.subject for OIDs
self.common_name = self._cert.subject.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value self.common_name = self._cert.subject.get_attributes_for_oid(
cryptography.x509.OID_COMMON_NAME
)[0].value
except Exception as err: except Exception as err:
if isinstance(err, AWSValidationException): if isinstance(err, AWSValidationException):
raise raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.') raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
)
def validate_chain(self): def validate_chain(self):
try: try:
self._chain = [] self._chain = []
for cert_armored in self.chain.split(b'-\n-'): for cert_armored in self.chain.split(b"-\n-"):
# Would leave encoded but Py2 does not have raw binary strings # Would leave encoded but Py2 does not have raw binary strings
cert_armored = cert_armored.decode() cert_armored = cert_armored.decode()
# Fix missing -'s on split # Fix missing -'s on split
cert_armored = re.sub(r'^----B', '-----B', cert_armored) cert_armored = re.sub(r"^----B", "-----B", cert_armored)
cert_armored = re.sub(r'E----$', 'E-----', cert_armored) cert_armored = re.sub(r"E----$", "E-----", cert_armored)
cert = cryptography.x509.load_pem_x509_certificate(cert_armored.encode(), default_backend()) cert = cryptography.x509.load_pem_x509_certificate(
cert_armored.encode(), default_backend()
)
self._chain.append(cert) self._chain.append(cert)
now = datetime.datetime.now() now = datetime.datetime.now()
if self._cert.not_valid_after < now: if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate chain has expired, is not valid.') raise AWSValidationException(
"The certificate chain has expired, is not valid."
)
if self._cert.not_valid_before > now: if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate chain is not in effect yet, is not valid.') raise AWSValidationException(
"The certificate chain is not in effect yet, is not valid."
)
except Exception as err: except Exception as err:
if isinstance(err, AWSValidationException): if isinstance(err, AWSValidationException):
raise raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.') raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
)
def check(self): def check(self):
# Basically, if the certificate is pending, and then checked again after 1 min # Basically, if the certificate is pending, and then checked again after 1 min
# It will appear as if its been validated # It will appear as if its been validated
if self.type == 'AMAZON_ISSUED' and self.status == 'PENDING_VALIDATION' and \ if (
(datetime.datetime.now() - self.created_at).total_seconds() > 60: # 1min self.type == "AMAZON_ISSUED"
self.status = 'ISSUED' and self.status == "PENDING_VALIDATION"
and (datetime.datetime.now() - self.created_at).total_seconds() > 60
): # 1min
self.status = "ISSUED"
def describe(self): def describe(self):
# 'RenewalSummary': {}, # Only when cert is amazon issued # 'RenewalSummary': {}, # Only when cert is amazon issued
if self._key.key_size == 1024: if self._key.key_size == 1024:
key_algo = 'RSA_1024' key_algo = "RSA_1024"
elif self._key.key_size == 2048: elif self._key.key_size == 2048:
key_algo = 'RSA_2048' key_algo = "RSA_2048"
else: else:
key_algo = 'EC_prime256v1' key_algo = "EC_prime256v1"
# Look for SANs # Look for SANs
san_obj = self._cert.extensions.get_extension_for_oid(cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME) san_obj = self._cert.extensions.get_extension_for_oid(
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
)
sans = [] sans = []
if san_obj is not None: if san_obj is not None:
sans = [item.value for item in san_obj.value] sans = [item.value for item in san_obj.value]
result = { result = {
'Certificate': { "Certificate": {
'CertificateArn': self.arn, "CertificateArn": self.arn,
'DomainName': self.common_name, "DomainName": self.common_name,
'InUseBy': [], "InUseBy": [],
'Issuer': self._cert.issuer.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value, "Issuer": self._cert.issuer.get_attributes_for_oid(
'KeyAlgorithm': key_algo, cryptography.x509.OID_COMMON_NAME
'NotAfter': datetime_to_epoch(self._cert.not_valid_after), )[0].value,
'NotBefore': datetime_to_epoch(self._cert.not_valid_before), "KeyAlgorithm": key_algo,
'Serial': self._cert.serial_number, "NotAfter": datetime_to_epoch(self._cert.not_valid_after),
'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''), "NotBefore": datetime_to_epoch(self._cert.not_valid_before),
'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED. "Serial": self._cert.serial_number,
'Subject': 'CN={0}'.format(self.common_name), "SignatureAlgorithm": self._cert.signature_algorithm_oid._name.upper().replace(
'SubjectAlternativeNames': sans, "ENCRYPTION", ""
'Type': self.type # One of IMPORTED, AMAZON_ISSUED ),
"Status": self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED.
"Subject": "CN={0}".format(self.common_name),
"SubjectAlternativeNames": sans,
"Type": self.type, # One of IMPORTED, AMAZON_ISSUED
} }
} }
if self.type == 'IMPORTED': if self.type == "IMPORTED":
result['Certificate']['ImportedAt'] = datetime_to_epoch(self.created_at) result["Certificate"]["ImportedAt"] = datetime_to_epoch(self.created_at)
else: else:
result['Certificate']['CreatedAt'] = datetime_to_epoch(self.created_at) result["Certificate"]["CreatedAt"] = datetime_to_epoch(self.created_at)
result['Certificate']['IssuedAt'] = datetime_to_epoch(self.created_at) result["Certificate"]["IssuedAt"] = datetime_to_epoch(self.created_at)
return result return result
@ -264,7 +334,7 @@ class CertBundle(BaseModel):
return self.arn return self.arn
def __repr__(self): def __repr__(self):
return '<Certificate>' return "<Certificate>"
class AWSCertificateManagerBackend(BaseBackend): class AWSCertificateManagerBackend(BaseBackend):
@ -281,7 +351,9 @@ class AWSCertificateManagerBackend(BaseBackend):
@staticmethod @staticmethod
def _arn_not_found(arn): def _arn_not_found(arn):
msg = 'Certificate with arn {0} not found in account {1}'.format(arn, DEFAULT_ACCOUNT_ID) msg = "Certificate with arn {0} not found in account {1}".format(
arn, DEFAULT_ACCOUNT_ID
)
return AWSResourceNotFoundException(msg) return AWSResourceNotFoundException(msg)
def _get_arn_from_idempotency_token(self, token): def _get_arn_from_idempotency_token(self, token):
@ -298,17 +370,20 @@ class AWSCertificateManagerBackend(BaseBackend):
""" """
now = datetime.datetime.now() now = datetime.datetime.now()
if token in self._idempotency_tokens: if token in self._idempotency_tokens:
if self._idempotency_tokens[token]['expires'] < now: if self._idempotency_tokens[token]["expires"] < now:
# Token has expired, new request # Token has expired, new request
del self._idempotency_tokens[token] del self._idempotency_tokens[token]
return None return None
else: else:
return self._idempotency_tokens[token]['arn'] return self._idempotency_tokens[token]["arn"]
return None return None
def _set_idempotency_token_arn(self, token, arn): def _set_idempotency_token_arn(self, token, arn):
self._idempotency_tokens[token] = {'arn': arn, 'expires': datetime.datetime.now() + datetime.timedelta(hours=1)} self._idempotency_tokens[token] = {
"arn": arn,
"expires": datetime.datetime.now() + datetime.timedelta(hours=1),
}
def import_cert(self, certificate, private_key, chain=None, arn=None): def import_cert(self, certificate, private_key, chain=None, arn=None):
if arn is not None: if arn is not None:
@ -316,7 +391,9 @@ class AWSCertificateManagerBackend(BaseBackend):
raise self._arn_not_found(arn) raise self._arn_not_found(arn)
else: else:
# Will reuse provided ARN # Will reuse provided ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region, arn=arn) bundle = CertBundle(
certificate, private_key, chain=chain, region=region, arn=arn
)
else: else:
# Will generate a random ARN # Will generate a random ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region) bundle = CertBundle(certificate, private_key, chain=chain, region=region)
@ -325,7 +402,7 @@ class AWSCertificateManagerBackend(BaseBackend):
return bundle.arn return bundle.arn
def get_certificates_list(self): def get_certificates_list(self, statuses):
""" """
Get list of certificates Get list of certificates
@ -333,7 +410,9 @@ class AWSCertificateManagerBackend(BaseBackend):
:rtype: list of CertBundle :rtype: list of CertBundle
""" """
for arn in self._certificates.keys(): for arn in self._certificates.keys():
yield self.get_certificate(arn) cert = self.get_certificate(arn)
if not statuses or cert.status in statuses:
yield cert
def get_certificate(self, arn): def get_certificate(self, arn):
if arn not in self._certificates: if arn not in self._certificates:
@ -349,13 +428,21 @@ class AWSCertificateManagerBackend(BaseBackend):
del self._certificates[arn] del self._certificates[arn]
def request_certificate(self, domain_name, domain_validation_options, idempotency_token, subject_alt_names): def request_certificate(
self,
domain_name,
domain_validation_options,
idempotency_token,
subject_alt_names,
):
if idempotency_token is not None: if idempotency_token is not None:
arn = self._get_arn_from_idempotency_token(idempotency_token) arn = self._get_arn_from_idempotency_token(idempotency_token)
if arn is not None: if arn is not None:
return arn return arn
cert = CertBundle.generate_cert(domain_name, subject_alt_names) cert = CertBundle.generate_cert(
domain_name, region=self.region, sans=subject_alt_names
)
if idempotency_token is not None: if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn) self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert self._certificates[cert.arn] = cert
@ -367,8 +454,8 @@ class AWSCertificateManagerBackend(BaseBackend):
cert_bundle = self.get_certificate(arn) cert_bundle = self.get_certificate(arn)
for tag in tags: for tag in tags:
key = tag['Key'] key = tag["Key"]
value = tag.get('Value', None) value = tag.get("Value", None)
cert_bundle.tags[key] = value cert_bundle.tags[key] = value
def remove_tags_from_certificate(self, arn, tags): def remove_tags_from_certificate(self, arn, tags):
@ -376,8 +463,8 @@ class AWSCertificateManagerBackend(BaseBackend):
cert_bundle = self.get_certificate(arn) cert_bundle = self.get_certificate(arn)
for tag in tags: for tag in tags:
key = tag['Key'] key = tag["Key"]
value = tag.get('Value', None) value = tag.get("Value", None)
try: try:
# If value isnt provided, just delete key # If value isnt provided, just delete key

View File

@ -7,7 +7,6 @@ from .models import acm_backends, AWSError, AWSValidationException
class AWSCertificateManagerResponse(BaseResponse): class AWSCertificateManagerResponse(BaseResponse):
@property @property
def acm_backend(self): def acm_backend(self):
""" """
@ -29,40 +28,49 @@ class AWSCertificateManagerResponse(BaseResponse):
return self.request_params.get(param, default) return self.request_params.get(param, default)
def add_tags_to_certificate(self): def add_tags_to_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
tags = self._get_param('Tags') tags = self._get_param("Tags")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
self.acm_backend.add_tags_to_certificate(arn, tags) self.acm_backend.add_tags_to_certificate(arn, tags)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
def delete_certificate(self): def delete_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
self.acm_backend.delete_certificate(arn) self.acm_backend.delete_certificate(arn)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
def describe_certificate(self): def describe_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
@ -72,11 +80,14 @@ class AWSCertificateManagerResponse(BaseResponse):
return json.dumps(cert_bundle.describe()) return json.dumps(cert_bundle.describe())
def get_certificate(self): def get_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
@ -84,8 +95,8 @@ class AWSCertificateManagerResponse(BaseResponse):
return err.response() return err.response()
result = { result = {
'Certificate': cert_bundle.cert.decode(), "Certificate": cert_bundle.cert.decode(),
'CertificateChain': cert_bundle.chain.decode() "CertificateChain": cert_bundle.chain.decode(),
} }
return json.dumps(result) return json.dumps(result)
@ -102,104 +113,129 @@ class AWSCertificateManagerResponse(BaseResponse):
:return: str(JSON) for response :return: str(JSON) for response
""" """
certificate = self._get_param('Certificate') certificate = self._get_param("Certificate")
private_key = self._get_param('PrivateKey') private_key = self._get_param("PrivateKey")
chain = self._get_param('CertificateChain') # Optional chain = self._get_param("CertificateChain") # Optional
current_arn = self._get_param('CertificateArn') # Optional current_arn = self._get_param("CertificateArn") # Optional
# Simple parameter decoding. Rather do it here as its a data transport decision not part of the # Simple parameter decoding. Rather do it here as its a data transport decision not part of the
# actual data # actual data
try: try:
certificate = base64.standard_b64decode(certificate) certificate = base64.standard_b64decode(certificate)
except Exception: except Exception:
return AWSValidationException('The certificate is not PEM-encoded or is not valid.').response() return AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
).response()
try: try:
private_key = base64.standard_b64decode(private_key) private_key = base64.standard_b64decode(private_key)
except Exception: except Exception:
return AWSValidationException('The private key is not PEM-encoded or is not valid.').response() return AWSValidationException(
"The private key is not PEM-encoded or is not valid."
).response()
if chain is not None: if chain is not None:
try: try:
chain = base64.standard_b64decode(chain) chain = base64.standard_b64decode(chain)
except Exception: except Exception:
return AWSValidationException('The certificate chain is not PEM-encoded or is not valid.').response() return AWSValidationException(
"The certificate chain is not PEM-encoded or is not valid."
).response()
try: try:
arn = self.acm_backend.import_cert(certificate, private_key, chain=chain, arn=current_arn) arn = self.acm_backend.import_cert(
certificate, private_key, chain=chain, arn=current_arn
)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return json.dumps({'CertificateArn': arn}) return json.dumps({"CertificateArn": arn})
def list_certificates(self): def list_certificates(self):
certs = [] certs = []
statuses = self._get_param("CertificateStatuses")
for cert_bundle in self.acm_backend.get_certificates_list(statuses):
certs.append(
{
"CertificateArn": cert_bundle.arn,
"DomainName": cert_bundle.common_name,
}
)
for cert_bundle in self.acm_backend.get_certificates_list(): result = {"CertificateSummaryList": certs}
certs.append({
'CertificateArn': cert_bundle.arn,
'DomainName': cert_bundle.common_name
})
result = {'CertificateSummaryList': certs}
return json.dumps(result) return json.dumps(result)
def list_tags_for_certificate(self): def list_tags_for_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return {'__type': 'MissingParameter', 'message': msg}, dict(status=400) return {"__type": "MissingParameter", "message": msg}, dict(status=400)
try: try:
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = {'Tags': []} result = {"Tags": []}
# Tag "objects" can not contain the Value part # Tag "objects" can not contain the Value part
for key, value in cert_bundle.tags.items(): for key, value in cert_bundle.tags.items():
tag_dict = {'Key': key} tag_dict = {"Key": key}
if value is not None: if value is not None:
tag_dict['Value'] = value tag_dict["Value"] = value
result['Tags'].append(tag_dict) result["Tags"].append(tag_dict)
return json.dumps(result) return json.dumps(result)
def remove_tags_from_certificate(self): def remove_tags_from_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
tags = self._get_param('Tags') tags = self._get_param("Tags")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
self.acm_backend.remove_tags_from_certificate(arn, tags) self.acm_backend.remove_tags_from_certificate(arn, tags)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
def request_certificate(self): def request_certificate(self):
domain_name = self._get_param('DomainName') domain_name = self._get_param("DomainName")
domain_validation_options = self._get_param('DomainValidationOptions') # is ignored atm domain_validation_options = self._get_param(
idempotency_token = self._get_param('IdempotencyToken') "DomainValidationOptions"
subject_alt_names = self._get_param('SubjectAlternativeNames') ) # is ignored atm
idempotency_token = self._get_param("IdempotencyToken")
subject_alt_names = self._get_param("SubjectAlternativeNames")
if subject_alt_names is not None and len(subject_alt_names) > 10: if subject_alt_names is not None and len(subject_alt_names) > 10:
# There is initial AWS limit of 10 # There is initial AWS limit of 10
msg = 'An ACM limit has been exceeded. Need to request SAN limit to be raised' msg = (
return json.dumps({'__type': 'LimitExceededException', 'message': msg}), dict(status=400) "An ACM limit has been exceeded. Need to request SAN limit to be raised"
)
return (
json.dumps({"__type": "LimitExceededException", "message": msg}),
dict(status=400),
)
try: try:
arn = self.acm_backend.request_certificate(domain_name, domain_validation_options, idempotency_token, subject_alt_names) arn = self.acm_backend.request_certificate(
domain_name,
domain_validation_options,
idempotency_token,
subject_alt_names,
)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return json.dumps({'CertificateArn': arn}) return json.dumps({"CertificateArn": arn})
def resend_validation_email(self): def resend_validation_email(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
domain = self._get_param('Domain') domain = self._get_param("Domain")
# ValidationDomain not used yet. # ValidationDomain not used yet.
# Contains domain which is equal to or a subset of Domain # Contains domain which is equal to or a subset of Domain
# that AWS will send validation emails to # that AWS will send validation emails to
@ -207,18 +243,21 @@ class AWSCertificateManagerResponse(BaseResponse):
# validation_domain = self._get_param('ValidationDomain') # validation_domain = self._get_param('ValidationDomain')
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
if cert_bundle.common_name != domain: if cert_bundle.common_name != domain:
msg = 'Parameter Domain does not match certificate domain' msg = "Parameter Domain does not match certificate domain"
_type = 'InvalidDomainValidationOptionsException' _type = "InvalidDomainValidationOptionsException"
return json.dumps({'__type': _type, 'message': msg}), dict(status=400) return json.dumps({"__type": _type, "message": msg}), dict(status=400)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import AWSCertificateManagerResponse from .responses import AWSCertificateManagerResponse
url_bases = [ url_bases = ["https?://acm.(.+).amazonaws.com"]
"https?://acm.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": AWSCertificateManagerResponse.dispatch}
'{0}/$': AWSCertificateManagerResponse.dispatch,
}

View File

@ -4,4 +4,6 @@ import uuid
def make_arn_for_certificate(account_id, region_name): def make_arn_for_certificate(account_id, region_name):
# Example # Example
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b # arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(region_name, account_id, uuid.uuid4()) return "arn:aws:acm:{0}:{1}:certificate/{2}".format(
region_name, account_id, uuid.uuid4()
)

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import apigateway_backends from .models import apigateway_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
apigateway_backend = apigateway_backends['us-east-1'] apigateway_backend = apigateway_backends["us-east-1"]
mock_apigateway = base_decorator(apigateway_backends) mock_apigateway = base_decorator(apigateway_backends)
mock_apigateway_deprecated = deprecated_base_decorator(apigateway_backends) mock_apigateway_deprecated = deprecated_base_decorator(apigateway_backends)

View File

@ -2,12 +2,96 @@ from __future__ import unicode_literals
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
class BadRequestException(RESTError):
pass
class AwsProxyNotAllowed(BadRequestException):
def __init__(self):
super(AwsProxyNotAllowed, self).__init__(
"BadRequestException",
"Integrations of type 'AWS_PROXY' currently only supports Lambda function and Firehose stream invocations.",
)
class CrossAccountNotAllowed(RESTError):
def __init__(self):
super(CrossAccountNotAllowed, self).__init__(
"AccessDeniedException", "Cross-account pass role is not allowed."
)
class RoleNotSpecified(BadRequestException):
def __init__(self):
super(RoleNotSpecified, self).__init__(
"BadRequestException", "Role ARN must be specified for AWS integrations"
)
class IntegrationMethodNotDefined(BadRequestException):
def __init__(self):
super(IntegrationMethodNotDefined, self).__init__(
"BadRequestException", "Enumeration value for HttpMethod must be non-empty"
)
class InvalidResourcePathException(BadRequestException):
def __init__(self):
super(InvalidResourcePathException, self).__init__(
"BadRequestException",
"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end.",
)
class InvalidHttpEndpoint(BadRequestException):
def __init__(self):
super(InvalidHttpEndpoint, self).__init__(
"BadRequestException", "Invalid HTTP endpoint specified for URI"
)
class InvalidArn(BadRequestException):
def __init__(self):
super(InvalidArn, self).__init__(
"BadRequestException", "Invalid ARN specified in the request"
)
class InvalidIntegrationArn(BadRequestException):
def __init__(self):
super(InvalidIntegrationArn, self).__init__(
"BadRequestException", "AWS ARN for integration must contain path or action"
)
class InvalidRequestInput(BadRequestException):
def __init__(self):
super(InvalidRequestInput, self).__init__(
"BadRequestException", "Invalid request input"
)
class NoIntegrationDefined(BadRequestException):
def __init__(self):
super(NoIntegrationDefined, self).__init__(
"BadRequestException", "No integration defined for method"
)
class NoMethodDefined(BadRequestException):
def __init__(self):
super(NoMethodDefined, self).__init__(
"BadRequestException", "The REST API doesn't contain any methods"
)
class StageNotFoundException(RESTError): class StageNotFoundException(RESTError):
code = 404 code = 404
def __init__(self): def __init__(self):
super(StageNotFoundException, self).__init__( super(StageNotFoundException, self).__init__(
"NotFoundException", "Invalid stage identifier specified") "NotFoundException", "Invalid stage identifier specified"
)
class ApiKeyNotFoundException(RESTError): class ApiKeyNotFoundException(RESTError):
@ -15,4 +99,14 @@ class ApiKeyNotFoundException(RESTError):
def __init__(self): def __init__(self):
super(ApiKeyNotFoundException, self).__init__( super(ApiKeyNotFoundException, self).__init__(
"NotFoundException", "Invalid API Key identifier specified") "NotFoundException", "Invalid API Key identifier specified"
)
class ApiKeyAlreadyExists(RESTError):
code = 409
def __init__(self):
super(ApiKeyAlreadyExists, self).__init__(
"ConflictException", "API Key already exists"
)

View File

@ -3,53 +3,69 @@ from __future__ import unicode_literals
import random import random
import string import string
import re
import requests import requests
import time import time
from boto3.session import Session from boto3.session import Session
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
import responses import responses
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from .utils import create_id from .utils import create_id
from moto.core.utils import path_url from moto.core.utils import path_url
from .exceptions import StageNotFoundException, ApiKeyNotFoundException from moto.sts.models import ACCOUNT_ID
from .exceptions import (
ApiKeyNotFoundException,
AwsProxyNotAllowed,
CrossAccountNotAllowed,
IntegrationMethodNotDefined,
InvalidArn,
InvalidIntegrationArn,
InvalidHttpEndpoint,
InvalidResourcePathException,
InvalidRequestInput,
StageNotFoundException,
RoleNotSpecified,
NoIntegrationDefined,
NoMethodDefined,
ApiKeyAlreadyExists,
)
STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}" STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
class Deployment(BaseModel, dict): class Deployment(BaseModel, dict):
def __init__(self, deployment_id, name, description=""): def __init__(self, deployment_id, name, description=""):
super(Deployment, self).__init__() super(Deployment, self).__init__()
self['id'] = deployment_id self["id"] = deployment_id
self['stageName'] = name self["stageName"] = name
self['description'] = description self["description"] = description
self['createdDate'] = int(time.time()) self["createdDate"] = int(time.time())
class IntegrationResponse(BaseModel, dict): class IntegrationResponse(BaseModel, dict):
def __init__(self, status_code, selection_pattern=None): def __init__(self, status_code, selection_pattern=None):
self['responseTemplates'] = {"application/json": None} self["responseTemplates"] = {"application/json": None}
self['statusCode'] = status_code self["statusCode"] = status_code
if selection_pattern: if selection_pattern:
self['selectionPattern'] = selection_pattern self["selectionPattern"] = selection_pattern
class Integration(BaseModel, dict): class Integration(BaseModel, dict):
def __init__(self, integration_type, uri, http_method, request_templates=None): def __init__(self, integration_type, uri, http_method, request_templates=None):
super(Integration, self).__init__() super(Integration, self).__init__()
self['type'] = integration_type self["type"] = integration_type
self['uri'] = uri self["uri"] = uri
self['httpMethod'] = http_method self["httpMethod"] = http_method
self['requestTemplates'] = request_templates self["requestTemplates"] = request_templates
self["integrationResponses"] = { self["integrationResponses"] = {"200": IntegrationResponse(200)}
"200": IntegrationResponse(200)
}
def create_integration_response(self, status_code, selection_pattern): def create_integration_response(self, status_code, selection_pattern):
integration_response = IntegrationResponse( integration_response = IntegrationResponse(status_code, selection_pattern)
status_code, selection_pattern)
self["integrationResponses"][status_code] = integration_response self["integrationResponses"][status_code] = integration_response
return integration_response return integration_response
@ -61,25 +77,25 @@ class Integration(BaseModel, dict):
class MethodResponse(BaseModel, dict): class MethodResponse(BaseModel, dict):
def __init__(self, status_code): def __init__(self, status_code):
super(MethodResponse, self).__init__() super(MethodResponse, self).__init__()
self['statusCode'] = status_code self["statusCode"] = status_code
class Method(BaseModel, dict): class Method(BaseModel, dict):
def __init__(self, method_type, authorization_type): def __init__(self, method_type, authorization_type):
super(Method, self).__init__() super(Method, self).__init__()
self.update(dict( self.update(
httpMethod=method_type, dict(
authorizationType=authorization_type, httpMethod=method_type,
authorizerId=None, authorizationType=authorization_type,
apiKeyRequired=None, authorizerId=None,
requestParameters=None, apiKeyRequired=None,
requestModels=None, requestParameters=None,
methodIntegration=None, requestModels=None,
)) methodIntegration=None,
)
)
self.method_responses = {} self.method_responses = {}
def create_response(self, response_code): def create_response(self, response_code):
@ -95,16 +111,13 @@ class Method(BaseModel, dict):
class Resource(BaseModel): class Resource(BaseModel):
def __init__(self, id, region_name, api_id, path_part, parent_id): def __init__(self, id, region_name, api_id, path_part, parent_id):
self.id = id self.id = id
self.region_name = region_name self.region_name = region_name
self.api_id = api_id self.api_id = api_id
self.path_part = path_part self.path_part = path_part
self.parent_id = parent_id self.parent_id = parent_id
self.resource_methods = { self.resource_methods = {"GET": {}}
'GET': {}
}
def to_dict(self): def to_dict(self):
response = { response = {
@ -113,8 +126,8 @@ class Resource(BaseModel):
"resourceMethods": self.resource_methods, "resourceMethods": self.resource_methods,
} }
if self.parent_id: if self.parent_id:
response['parentId'] = self.parent_id response["parentId"] = self.parent_id
response['pathPart'] = self.path_part response["pathPart"] = self.path_part
return response return response
def get_path(self): def get_path(self):
@ -125,102 +138,112 @@ class Resource(BaseModel):
backend = apigateway_backends[self.region_name] backend = apigateway_backends[self.region_name]
parent = backend.get_resource(self.api_id, self.parent_id) parent = backend.get_resource(self.api_id, self.parent_id)
parent_path = parent.get_path() parent_path = parent.get_path()
if parent_path != '/': # Root parent if parent_path != "/": # Root parent
parent_path += '/' parent_path += "/"
return parent_path return parent_path
else: else:
return '' return ""
def get_response(self, request): def get_response(self, request):
integration = self.get_integration(request.method) integration = self.get_integration(request.method)
integration_type = integration['type'] integration_type = integration["type"]
if integration_type == 'HTTP': if integration_type == "HTTP":
uri = integration['uri'] uri = integration["uri"]
requests_func = getattr(requests, integration[ requests_func = getattr(requests, integration["httpMethod"].lower())
'httpMethod'].lower())
response = requests_func(uri) response = requests_func(uri)
else: else:
raise NotImplementedError( raise NotImplementedError(
"The {0} type has not been implemented".format(integration_type)) "The {0} type has not been implemented".format(integration_type)
)
return response.status_code, response.text return response.status_code, response.text
def add_method(self, method_type, authorization_type): def add_method(self, method_type, authorization_type):
method = Method(method_type=method_type, method = Method(method_type=method_type, authorization_type=authorization_type)
authorization_type=authorization_type)
self.resource_methods[method_type] = method self.resource_methods[method_type] = method
return method return method
def get_method(self, method_type): def get_method(self, method_type):
return self.resource_methods[method_type] return self.resource_methods[method_type]
def add_integration(self, method_type, integration_type, uri, request_templates=None): def add_integration(
self, method_type, integration_type, uri, request_templates=None
):
integration = Integration( integration = Integration(
integration_type, uri, method_type, request_templates=request_templates) integration_type, uri, method_type, request_templates=request_templates
self.resource_methods[method_type]['methodIntegration'] = integration )
self.resource_methods[method_type]["methodIntegration"] = integration
return integration return integration
def get_integration(self, method_type): def get_integration(self, method_type):
return self.resource_methods[method_type]['methodIntegration'] return self.resource_methods[method_type]["methodIntegration"]
def delete_integration(self, method_type): def delete_integration(self, method_type):
return self.resource_methods[method_type].pop('methodIntegration') return self.resource_methods[method_type].pop("methodIntegration")
class Stage(BaseModel, dict): class Stage(BaseModel, dict):
def __init__(
def __init__(self, name=None, deployment_id=None, variables=None, self,
description='', cacheClusterEnabled=False, cacheClusterSize=None): name=None,
deployment_id=None,
variables=None,
description="",
cacheClusterEnabled=False,
cacheClusterSize=None,
):
super(Stage, self).__init__() super(Stage, self).__init__()
if variables is None: if variables is None:
variables = {} variables = {}
self['stageName'] = name self["stageName"] = name
self['deploymentId'] = deployment_id self["deploymentId"] = deployment_id
self['methodSettings'] = {} self["methodSettings"] = {}
self['variables'] = variables self["variables"] = variables
self['description'] = description self["description"] = description
self['cacheClusterEnabled'] = cacheClusterEnabled self["cacheClusterEnabled"] = cacheClusterEnabled
if self['cacheClusterEnabled']: if self["cacheClusterEnabled"]:
self['cacheClusterSize'] = str(0.5) self["cacheClusterSize"] = str(0.5)
if cacheClusterSize is not None: if cacheClusterSize is not None:
self['cacheClusterSize'] = str(cacheClusterSize) self["cacheClusterSize"] = str(cacheClusterSize)
def apply_operations(self, patch_operations): def apply_operations(self, patch_operations):
for op in patch_operations: for op in patch_operations:
if 'variables/' in op['path']: if "variables/" in op["path"]:
self._apply_operation_to_variables(op) self._apply_operation_to_variables(op)
elif '/cacheClusterEnabled' in op['path']: elif "/cacheClusterEnabled" in op["path"]:
self['cacheClusterEnabled'] = self._str2bool(op['value']) self["cacheClusterEnabled"] = self._str2bool(op["value"])
if 'cacheClusterSize' not in self and self['cacheClusterEnabled']: if "cacheClusterSize" not in self and self["cacheClusterEnabled"]:
self['cacheClusterSize'] = str(0.5) self["cacheClusterSize"] = str(0.5)
elif '/cacheClusterSize' in op['path']: elif "/cacheClusterSize" in op["path"]:
self['cacheClusterSize'] = str(float(op['value'])) self["cacheClusterSize"] = str(float(op["value"]))
elif '/description' in op['path']: elif "/description" in op["path"]:
self['description'] = op['value'] self["description"] = op["value"]
elif '/deploymentId' in op['path']: elif "/deploymentId" in op["path"]:
self['deploymentId'] = op['value'] self["deploymentId"] = op["value"]
elif op['op'] == 'replace': elif op["op"] == "replace":
# Method Settings drop into here # Method Settings drop into here
# (e.g., path could be '/*/*/logging/loglevel') # (e.g., path could be '/*/*/logging/loglevel')
split_path = op['path'].split('/', 3) split_path = op["path"].split("/", 3)
if len(split_path) != 4: if len(split_path) != 4:
continue continue
self._patch_method_setting( self._patch_method_setting(
'/'.join(split_path[1:3]), split_path[3], op['value']) "/".join(split_path[1:3]), split_path[3], op["value"]
)
else: else:
raise Exception( raise Exception('Patch operation "%s" not implemented' % op["op"])
'Patch operation "%s" not implemented' % op['op'])
return self return self
def _patch_method_setting(self, resource_path_and_method, key, value): def _patch_method_setting(self, resource_path_and_method, key, value):
updated_key = self._method_settings_translations(key) updated_key = self._method_settings_translations(key)
if updated_key is not None: if updated_key is not None:
if resource_path_and_method not in self['methodSettings']: if resource_path_and_method not in self["methodSettings"]:
self['methodSettings'][ self["methodSettings"][
resource_path_and_method] = self._get_default_method_settings() resource_path_and_method
self['methodSettings'][resource_path_and_method][ ] = self._get_default_method_settings()
updated_key] = self._convert_to_type(updated_key, value) self["methodSettings"][resource_path_and_method][
updated_key
] = self._convert_to_type(updated_key, value)
def _get_default_method_settings(self): def _get_default_method_settings(self):
return { return {
@ -232,21 +255,21 @@ class Stage(BaseModel, dict):
"cacheDataEncrypted": True, "cacheDataEncrypted": True,
"cachingEnabled": False, "cachingEnabled": False,
"throttlingBurstLimit": 2000, "throttlingBurstLimit": 2000,
"requireAuthorizationForCacheControl": True "requireAuthorizationForCacheControl": True,
} }
def _method_settings_translations(self, key): def _method_settings_translations(self, key):
mappings = { mappings = {
'metrics/enabled': 'metricsEnabled', "metrics/enabled": "metricsEnabled",
'logging/loglevel': 'loggingLevel', "logging/loglevel": "loggingLevel",
'logging/dataTrace': 'dataTraceEnabled', "logging/dataTrace": "dataTraceEnabled",
'throttling/burstLimit': 'throttlingBurstLimit', "throttling/burstLimit": "throttlingBurstLimit",
'throttling/rateLimit': 'throttlingRateLimit', "throttling/rateLimit": "throttlingRateLimit",
'caching/enabled': 'cachingEnabled', "caching/enabled": "cachingEnabled",
'caching/ttlInSeconds': 'cacheTtlInSeconds', "caching/ttlInSeconds": "cacheTtlInSeconds",
'caching/dataEncrypted': 'cacheDataEncrypted', "caching/dataEncrypted": "cacheDataEncrypted",
'caching/requireAuthorizationForCacheControl': 'requireAuthorizationForCacheControl', "caching/requireAuthorizationForCacheControl": "requireAuthorizationForCacheControl",
'caching/unauthorizedCacheControlHeaderStrategy': 'unauthorizedCacheControlHeaderStrategy' "caching/unauthorizedCacheControlHeaderStrategy": "unauthorizedCacheControlHeaderStrategy",
} }
if key in mappings: if key in mappings:
@ -259,26 +282,26 @@ class Stage(BaseModel, dict):
def _convert_to_type(self, key, val): def _convert_to_type(self, key, val):
type_mappings = { type_mappings = {
'metricsEnabled': 'bool', "metricsEnabled": "bool",
'loggingLevel': 'str', "loggingLevel": "str",
'dataTraceEnabled': 'bool', "dataTraceEnabled": "bool",
'throttlingBurstLimit': 'int', "throttlingBurstLimit": "int",
'throttlingRateLimit': 'float', "throttlingRateLimit": "float",
'cachingEnabled': 'bool', "cachingEnabled": "bool",
'cacheTtlInSeconds': 'int', "cacheTtlInSeconds": "int",
'cacheDataEncrypted': 'bool', "cacheDataEncrypted": "bool",
'requireAuthorizationForCacheControl': 'bool', "requireAuthorizationForCacheControl": "bool",
'unauthorizedCacheControlHeaderStrategy': 'str' "unauthorizedCacheControlHeaderStrategy": "str",
} }
if key in type_mappings: if key in type_mappings:
type_value = type_mappings[key] type_value = type_mappings[key]
if type_value == 'bool': if type_value == "bool":
return self._str2bool(val) return self._str2bool(val)
elif type_value == 'int': elif type_value == "int":
return int(val) return int(val)
elif type_value == 'float': elif type_value == "float":
return float(val) return float(val)
else: else:
return str(val) return str(val)
@ -286,55 +309,91 @@ class Stage(BaseModel, dict):
return str(val) return str(val)
def _apply_operation_to_variables(self, op): def _apply_operation_to_variables(self, op):
key = op['path'][op['path'].rindex("variables/") + 10:] key = op["path"][op["path"].rindex("variables/") + 10 :]
if op['op'] == 'remove': if op["op"] == "remove":
self['variables'].pop(key, None) self["variables"].pop(key, None)
elif op['op'] == 'replace': elif op["op"] == "replace":
self['variables'][key] = op['value'] self["variables"][key] = op["value"]
else: else:
raise Exception('Patch operation "%s" not implemented' % op['op']) raise Exception('Patch operation "%s" not implemented' % op["op"])
class ApiKey(BaseModel, dict): class ApiKey(BaseModel, dict):
def __init__(
def __init__(self, name=None, description=None, enabled=True, self,
generateDistinctId=False, value=None, stageKeys=None, customerId=None): name=None,
description=None,
enabled=True,
generateDistinctId=False,
value=None,
stageKeys=None,
tags=None,
customerId=None,
):
super(ApiKey, self).__init__() super(ApiKey, self).__init__()
self['id'] = create_id() self["id"] = create_id()
self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40)) self["value"] = (
self['name'] = name value
self['customerId'] = customerId if value
self['description'] = description else "".join(random.sample(string.ascii_letters + string.digits, 40))
self['enabled'] = enabled )
self['createdDate'] = self['lastUpdatedDate'] = int(time.time()) self["name"] = name
self['stageKeys'] = stageKeys self["customerId"] = customerId
self["description"] = description
self["enabled"] = enabled
self["createdDate"] = self["lastUpdatedDate"] = int(time.time())
self["stageKeys"] = stageKeys
self["tags"] = tags
def update_operations(self, patch_operations):
for op in patch_operations:
if op["op"] == "replace":
if "/name" in op["path"]:
self["name"] = op["value"]
elif "/customerId" in op["path"]:
self["customerId"] = op["value"]
elif "/description" in op["path"]:
self["description"] = op["value"]
elif "/enabled" in op["path"]:
self["enabled"] = self._str2bool(op["value"])
else:
raise Exception('Patch operation "%s" not implemented' % op["op"])
return self
def _str2bool(self, v):
return v.lower() == "true"
class UsagePlan(BaseModel, dict): class UsagePlan(BaseModel, dict):
def __init__(
def __init__(self, name=None, description=None, apiStages=[], self,
throttle=None, quota=None): name=None,
description=None,
apiStages=None,
throttle=None,
quota=None,
tags=None,
):
super(UsagePlan, self).__init__() super(UsagePlan, self).__init__()
self['id'] = create_id() self["id"] = create_id()
self['name'] = name self["name"] = name
self['description'] = description self["description"] = description
self['apiStages'] = apiStages self["apiStages"] = apiStages if apiStages else []
self['throttle'] = throttle self["throttle"] = throttle
self['quota'] = quota self["quota"] = quota
self["tags"] = tags
class UsagePlanKey(BaseModel, dict): class UsagePlanKey(BaseModel, dict):
def __init__(self, id, type, name, value): def __init__(self, id, type, name, value):
super(UsagePlanKey, self).__init__() super(UsagePlanKey, self).__init__()
self['id'] = id self["id"] = id
self['name'] = name self["name"] = name
self['type'] = type self["type"] = type
self['value'] = value self["value"] = value
class RestAPI(BaseModel): class RestAPI(BaseModel):
def __init__(self, id, region_name, name, description): def __init__(self, id, region_name, name, description):
self.id = id self.id = id
self.region_name = region_name self.region_name = region_name
@ -346,7 +405,7 @@ class RestAPI(BaseModel):
self.stages = {} self.stages = {}
self.resources = {} self.resources = {}
self.add_child('/') # Add default child self.add_child("/") # Add default child
def __repr__(self): def __repr__(self):
return str(self.id) return str(self.id)
@ -361,8 +420,13 @@ class RestAPI(BaseModel):
def add_child(self, path, parent_id=None): def add_child(self, path, parent_id=None):
child_id = create_id() child_id = create_id()
child = Resource(id=child_id, region_name=self.region_name, child = Resource(
api_id=self.id, path_part=path, parent_id=parent_id) id=child_id,
region_name=self.region_name,
api_id=self.id,
path_part=path,
parent_id=parent_id,
)
self.resources[child_id] = child self.resources[child_id] = child
return child return child
@ -374,30 +438,53 @@ class RestAPI(BaseModel):
def resource_callback(self, request): def resource_callback(self, request):
path = path_url(request.url) path = path_url(request.url)
path_after_stage_name = '/'.join(path.split("/")[2:]) path_after_stage_name = "/".join(path.split("/")[2:])
if not path_after_stage_name: if not path_after_stage_name:
path_after_stage_name = '/' path_after_stage_name = "/"
resource = self.get_resource_for_path(path_after_stage_name) resource = self.get_resource_for_path(path_after_stage_name)
status_code, response = resource.get_response(request) status_code, response = resource.get_response(request)
return status_code, {}, response return status_code, {}, response
def update_integration_mocks(self, stage_name): def update_integration_mocks(self, stage_name):
stage_url_lower = STAGE_URL.format(api_id=self.id.lower(), stage_url_lower = STAGE_URL.format(
region_name=self.region_name, stage_name=stage_name) api_id=self.id.lower(), region_name=self.region_name, stage_name=stage_name
stage_url_upper = STAGE_URL.format(api_id=self.id.upper(), )
region_name=self.region_name, stage_name=stage_name) stage_url_upper = STAGE_URL.format(
api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name
)
responses.add_callback(responses.GET, stage_url_lower, for url in [stage_url_lower, stage_url_upper]:
callback=self.resource_callback) responses._default_mock._matches.insert(
responses.add_callback(responses.GET, stage_url_upper, 0,
callback=self.resource_callback) responses.CallbackResponse(
url=url,
method=responses.GET,
callback=self.resource_callback,
content_type="text/plain",
match_querystring=False,
),
)
def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): def create_stage(
self,
name,
deployment_id,
variables=None,
description="",
cacheClusterEnabled=None,
cacheClusterSize=None,
):
if variables is None: if variables is None:
variables = {} variables = {}
stage = Stage(name=name, deployment_id=deployment_id, variables=variables, stage = Stage(
description=description, cacheClusterSize=cacheClusterSize, cacheClusterEnabled=cacheClusterEnabled) name=name,
deployment_id=deployment_id,
variables=variables,
description=description,
cacheClusterSize=cacheClusterSize,
cacheClusterEnabled=cacheClusterEnabled,
)
self.stages[name] = stage self.stages[name] = stage
self.update_integration_mocks(name) self.update_integration_mocks(name)
return stage return stage
@ -409,7 +496,8 @@ class RestAPI(BaseModel):
deployment = Deployment(deployment_id, name, description) deployment = Deployment(deployment_id, name, description)
self.deployments[deployment_id] = deployment self.deployments[deployment_id] = deployment
self.stages[name] = Stage( self.stages[name] = Stage(
name=name, deployment_id=deployment_id, variables=stage_variables) name=name, deployment_id=deployment_id, variables=stage_variables
)
self.update_integration_mocks(name) self.update_integration_mocks(name)
return deployment return deployment
@ -428,7 +516,6 @@ class RestAPI(BaseModel):
class APIGatewayBackend(BaseBackend): class APIGatewayBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name):
super(APIGatewayBackend, self).__init__() super(APIGatewayBackend, self).__init__()
self.apis = {} self.apis = {}
@ -469,11 +556,10 @@ class APIGatewayBackend(BaseBackend):
return resource return resource
def create_resource(self, function_id, parent_resource_id, path_part): def create_resource(self, function_id, parent_resource_id, path_part):
if not re.match("^\\{?[a-zA-Z0-9._-]+\\}?$", path_part):
raise InvalidResourcePathException()
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
child = api.add_child( child = api.add_child(path=path_part, parent_id=parent_resource_id)
path=path_part,
parent_id=parent_resource_id,
)
return child return child
def delete_resource(self, function_id, resource_id): def delete_resource(self, function_id, resource_id):
@ -502,13 +588,27 @@ class APIGatewayBackend(BaseBackend):
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
return api.get_stages() return api.get_stages()
def create_stage(self, function_id, stage_name, deploymentId, def create_stage(
variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): self,
function_id,
stage_name,
deploymentId,
variables=None,
description="",
cacheClusterEnabled=None,
cacheClusterSize=None,
):
if variables is None: if variables is None:
variables = {} variables = {}
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
api.create_stage(stage_name, deploymentId, variables=variables, api.create_stage(
description=description, cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) stage_name,
deploymentId,
variables=variables,
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
)
return api.stages.get(stage_name) return api.stages.get(stage_name)
def update_stage(self, function_id, stage_name, patch_operations): def update_stage(self, function_id, stage_name, patch_operations):
@ -518,26 +618,73 @@ class APIGatewayBackend(BaseBackend):
stage = api.stages[stage_name] = Stage() stage = api.stages[stage_name] = Stage()
return stage.apply_operations(patch_operations) return stage.apply_operations(patch_operations)
def delete_stage(self, function_id, stage_name):
api = self.get_rest_api(function_id)
del api.stages[stage_name]
def get_method_response(self, function_id, resource_id, method_type, response_code): def get_method_response(self, function_id, resource_id, method_type, response_code):
method = self.get_method(function_id, resource_id, method_type) method = self.get_method(function_id, resource_id, method_type)
method_response = method.get_response(response_code) method_response = method.get_response(response_code)
return method_response return method_response
def create_method_response(self, function_id, resource_id, method_type, response_code): def create_method_response(
self, function_id, resource_id, method_type, response_code
):
method = self.get_method(function_id, resource_id, method_type) method = self.get_method(function_id, resource_id, method_type)
method_response = method.create_response(response_code) method_response = method.create_response(response_code)
return method_response return method_response
def delete_method_response(self, function_id, resource_id, method_type, response_code): def delete_method_response(
self, function_id, resource_id, method_type, response_code
):
method = self.get_method(function_id, resource_id, method_type) method = self.get_method(function_id, resource_id, method_type)
method_response = method.delete_response(response_code) method_response = method.delete_response(response_code)
return method_response return method_response
def create_integration(self, function_id, resource_id, method_type, integration_type, uri, def create_integration(
request_templates=None): self,
function_id,
resource_id,
method_type,
integration_type,
uri,
integration_method=None,
credentials=None,
request_templates=None,
):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
integration = resource.add_integration(method_type, integration_type, uri, if credentials and not re.match(
request_templates=request_templates) "^arn:aws:iam::" + str(ACCOUNT_ID), credentials
):
raise CrossAccountNotAllowed()
if not integration_method and integration_type in [
"HTTP",
"HTTP_PROXY",
"AWS",
"AWS_PROXY",
]:
raise IntegrationMethodNotDefined()
if integration_type in ["AWS_PROXY"] and re.match(
"^arn:aws:apigateway:[a-zA-Z0-9-]+:s3", uri
):
raise AwsProxyNotAllowed()
if (
integration_type in ["AWS"]
and re.match("^arn:aws:apigateway:[a-zA-Z0-9-]+:s3", uri)
and not credentials
):
raise RoleNotSpecified()
if integration_type in ["HTTP", "HTTP_PROXY"] and not self._uri_validator(uri):
raise InvalidHttpEndpoint()
if integration_type in ["AWS", "AWS_PROXY"] and not re.match("^arn:aws:", uri):
raise InvalidArn()
if integration_type in ["AWS", "AWS_PROXY"] and not re.match(
"^arn:aws:apigateway:[a-zA-Z0-9-]+:[a-zA-Z0-9-]+:(path|action)/", uri
):
raise InvalidIntegrationArn()
integration = resource.add_integration(
method_type, integration_type, uri, request_templates=request_templates
)
return integration return integration
def get_integration(self, function_id, resource_id, method_type): def get_integration(self, function_id, resource_id, method_type):
@ -548,31 +695,55 @@ class APIGatewayBackend(BaseBackend):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
return resource.delete_integration(method_type) return resource.delete_integration(method_type)
def create_integration_response(self, function_id, resource_id, method_type, status_code, selection_pattern): def create_integration_response(
integration = self.get_integration( self,
function_id, resource_id, method_type) function_id,
resource_id,
method_type,
status_code,
selection_pattern,
response_templates,
):
if response_templates is None:
raise InvalidRequestInput()
integration = self.get_integration(function_id, resource_id, method_type)
integration_response = integration.create_integration_response( integration_response = integration.create_integration_response(
status_code, selection_pattern) status_code, selection_pattern
)
return integration_response return integration_response
def get_integration_response(self, function_id, resource_id, method_type, status_code): def get_integration_response(
integration = self.get_integration( self, function_id, resource_id, method_type, status_code
function_id, resource_id, method_type) ):
integration_response = integration.get_integration_response( integration = self.get_integration(function_id, resource_id, method_type)
status_code) integration_response = integration.get_integration_response(status_code)
return integration_response return integration_response
def delete_integration_response(self, function_id, resource_id, method_type, status_code): def delete_integration_response(
integration = self.get_integration( self, function_id, resource_id, method_type, status_code
function_id, resource_id, method_type) ):
integration_response = integration.delete_integration_response( integration = self.get_integration(function_id, resource_id, method_type)
status_code) integration_response = integration.delete_integration_response(status_code)
return integration_response return integration_response
def create_deployment(self, function_id, name, description="", stage_variables=None): def create_deployment(
self, function_id, name, description="", stage_variables=None
):
if stage_variables is None: if stage_variables is None:
stage_variables = {} stage_variables = {}
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
methods = [
list(res.resource_methods.values())
for res in self.list_resources(function_id)
][0]
if not any(methods):
raise NoMethodDefined()
method_integrations = [
method["methodIntegration"] if "methodIntegration" in method else None
for method in methods
]
if not any(method_integrations):
raise NoIntegrationDefined()
deployment = api.create_deployment(name, description, stage_variables) deployment = api.create_deployment(name, description, stage_variables)
return deployment return deployment
@ -589,8 +760,12 @@ class APIGatewayBackend(BaseBackend):
return api.delete_deployment(deployment_id) return api.delete_deployment(deployment_id)
def create_apikey(self, payload): def create_apikey(self, payload):
if payload.get("value") is not None:
for api_key in self.get_apikeys():
if api_key.get("value") == payload["value"]:
raise ApiKeyAlreadyExists()
key = ApiKey(**payload) key = ApiKey(**payload)
self.keys[key['id']] = key self.keys[key["id"]] = key
return key return key
def get_apikeys(self): def get_apikeys(self):
@ -599,13 +774,17 @@ class APIGatewayBackend(BaseBackend):
def get_apikey(self, api_key_id): def get_apikey(self, api_key_id):
return self.keys[api_key_id] return self.keys[api_key_id]
def update_apikey(self, api_key_id, patch_operations):
key = self.keys[api_key_id]
return key.update_operations(patch_operations)
def delete_apikey(self, api_key_id): def delete_apikey(self, api_key_id):
self.keys.pop(api_key_id) self.keys.pop(api_key_id)
return {} return {}
def create_usage_plan(self, payload): def create_usage_plan(self, payload):
plan = UsagePlan(**payload) plan = UsagePlan(**payload)
self.usage_plans[plan['id']] = plan self.usage_plans[plan["id"]] = plan
return plan return plan
def get_usage_plans(self, api_key_id=None): def get_usage_plans(self, api_key_id=None):
@ -614,7 +793,7 @@ class APIGatewayBackend(BaseBackend):
plans = [ plans = [
plan plan
for plan in plans for plan in plans
if self.usage_plan_keys.get(plan['id'], {}).get(api_key_id, False) if self.usage_plan_keys.get(plan["id"], {}).get(api_key_id, False)
] ]
return plans return plans
@ -635,8 +814,13 @@ class APIGatewayBackend(BaseBackend):
api_key = self.keys[key_id] api_key = self.keys[key_id]
usage_plan_key = UsagePlanKey(id=key_id, type=payload["keyType"], name=api_key["name"], value=api_key["value"]) usage_plan_key = UsagePlanKey(
self.usage_plan_keys[usage_plan_id][usage_plan_key['id']] = usage_plan_key id=key_id,
type=payload["keyType"],
name=api_key["name"],
value=api_key["value"],
)
self.usage_plan_keys[usage_plan_id][usage_plan_key["id"]] = usage_plan_key
return usage_plan_key return usage_plan_key
def get_usage_plan_keys(self, usage_plan_id): def get_usage_plan_keys(self, usage_plan_id):
@ -652,7 +836,22 @@ class APIGatewayBackend(BaseBackend):
self.usage_plan_keys[usage_plan_id].pop(key_id) self.usage_plan_keys[usage_plan_id].pop(key_id)
return {} return {}
def _uri_validator(self, uri):
try:
result = urlparse(uri)
return all([result.scheme, result.netloc, result.path])
except Exception:
return False
apigateway_backends = {} apigateway_backends = {}
for region_name in Session().get_available_regions('apigateway'): for region_name in Session().get_available_regions("apigateway"):
apigateway_backends[region_name] = APIGatewayBackend(region_name)
for region_name in Session().get_available_regions(
"apigateway", partition_name="aws-us-gov"
):
apigateway_backends[region_name] = APIGatewayBackend(region_name)
for region_name in Session().get_available_regions(
"apigateway", partition_name="aws-cn"
):
apigateway_backends[region_name] = APIGatewayBackend(region_name) apigateway_backends[region_name] = APIGatewayBackend(region_name)

View File

@ -4,13 +4,25 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import apigateway_backends from .models import apigateway_backends
from .exceptions import StageNotFoundException, ApiKeyNotFoundException from .exceptions import (
ApiKeyNotFoundException,
BadRequestException,
CrossAccountNotAllowed,
StageNotFoundException,
ApiKeyAlreadyExists,
)
class APIGatewayResponse(BaseResponse): class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400):
return (
status,
self.response_headers,
json.dumps({"__type": type_, "message": message}),
)
def _get_param(self, key): def _get_param(self, key):
return json.loads(self.body).get(key) return json.loads(self.body).get(key) if self.body else None
def _get_param_with_default_value(self, key, default): def _get_param_with_default_value(self, key, default):
jsonbody = json.loads(self.body) jsonbody = json.loads(self.body)
@ -27,14 +39,12 @@ class APIGatewayResponse(BaseResponse):
def restapis(self, request, full_url, headers): def restapis(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == 'GET': if self.method == "GET":
apis = self.backend.list_apis() apis = self.backend.list_apis()
return 200, {}, json.dumps({"item": [ return 200, {}, json.dumps({"item": [api.to_dict() for api in apis]})
api.to_dict() for api in apis elif self.method == "POST":
]}) name = self._get_param("name")
elif self.method == 'POST': description = self._get_param("description")
name = self._get_param('name')
description = self._get_param('description')
rest_api = self.backend.create_rest_api(name, description) rest_api = self.backend.create_rest_api(name, description)
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
@ -42,10 +52,10 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET': if self.method == "GET":
rest_api = self.backend.get_rest_api(function_id) rest_api = self.backend.get_rest_api(function_id)
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
elif self.method == 'DELETE': elif self.method == "DELETE":
rest_api = self.backend.delete_rest_api(function_id) rest_api = self.backend.delete_rest_api(function_id)
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
@ -53,26 +63,34 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET': if self.method == "GET":
resources = self.backend.list_resources(function_id) resources = self.backend.list_resources(function_id)
return 200, {}, json.dumps({"item": [ return (
resource.to_dict() for resource in resources 200,
]}) {},
json.dumps({"item": [resource.to_dict() for resource in resources]}),
)
def resource_individual(self, request, full_url, headers): def resource_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
resource_id = self.path.split("/")[-1] resource_id = self.path.split("/")[-1]
if self.method == 'GET': try:
resource = self.backend.get_resource(function_id, resource_id) if self.method == "GET":
elif self.method == 'POST': resource = self.backend.get_resource(function_id, resource_id)
path_part = self._get_param("pathPart") elif self.method == "POST":
resource = self.backend.create_resource( path_part = self._get_param("pathPart")
function_id, resource_id, path_part) resource = self.backend.create_resource(
elif self.method == 'DELETE': function_id, resource_id, path_part
resource = self.backend.delete_resource(function_id, resource_id) )
return 200, {}, json.dumps(resource.to_dict()) elif self.method == "DELETE":
resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict())
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
def resource_methods(self, request, full_url, headers): def resource_methods(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -81,14 +99,14 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4] resource_id = url_path_parts[4]
method_type = url_path_parts[6] method_type = url_path_parts[6]
if self.method == 'GET': if self.method == "GET":
method = self.backend.get_method( method = self.backend.get_method(function_id, resource_id, method_type)
function_id, resource_id, method_type)
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)
elif self.method == 'PUT': elif self.method == "PUT":
authorization_type = self._get_param("authorizationType") authorization_type = self._get_param("authorizationType")
method = self.backend.create_method( method = self.backend.create_method(
function_id, resource_id, method_type, authorization_type) function_id, resource_id, method_type, authorization_type
)
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)
def resource_method_responses(self, request, full_url, headers): def resource_method_responses(self, request, full_url, headers):
@ -99,15 +117,18 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
response_code = url_path_parts[8] response_code = url_path_parts[8]
if self.method == 'GET': if self.method == "GET":
method_response = self.backend.get_method_response( method_response = self.backend.get_method_response(
function_id, resource_id, method_type, response_code) function_id, resource_id, method_type, response_code
elif self.method == 'PUT': )
elif self.method == "PUT":
method_response = self.backend.create_method_response( method_response = self.backend.create_method_response(
function_id, resource_id, method_type, response_code) function_id, resource_id, method_type, response_code
elif self.method == 'DELETE': )
elif self.method == "DELETE":
method_response = self.backend.delete_method_response( method_response = self.backend.delete_method_response(
function_id, resource_id, method_type, response_code) function_id, resource_id, method_type, response_code
)
return 200, {}, json.dumps(method_response) return 200, {}, json.dumps(method_response)
def restapis_stages(self, request, full_url, headers): def restapis_stages(self, request, full_url, headers):
@ -115,21 +136,28 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
function_id = url_path_parts[2] function_id = url_path_parts[2]
if self.method == 'POST': if self.method == "POST":
stage_name = self._get_param("stageName") stage_name = self._get_param("stageName")
deployment_id = self._get_param("deploymentId") deployment_id = self._get_param("deploymentId")
stage_variables = self._get_param_with_default_value( stage_variables = self._get_param_with_default_value("variables", {})
'variables', {}) description = self._get_param_with_default_value("description", "")
description = self._get_param_with_default_value('description', '')
cacheClusterEnabled = self._get_param_with_default_value( cacheClusterEnabled = self._get_param_with_default_value(
'cacheClusterEnabled', False) "cacheClusterEnabled", False
)
cacheClusterSize = self._get_param_with_default_value( cacheClusterSize = self._get_param_with_default_value(
'cacheClusterSize', None) "cacheClusterSize", None
)
stage_response = self.backend.create_stage(function_id, stage_name, deployment_id, stage_response = self.backend.create_stage(
variables=stage_variables, description=description, function_id,
cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) stage_name,
elif self.method == 'GET': deployment_id,
variables=stage_variables,
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
)
elif self.method == "GET":
stages = self.backend.get_stages(function_id) stages = self.backend.get_stages(function_id)
return 200, {}, json.dumps({"item": stages}) return 200, {}, json.dumps({"item": stages})
@ -141,16 +169,25 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2] function_id = url_path_parts[2]
stage_name = url_path_parts[4] stage_name = url_path_parts[4]
if self.method == 'GET': if self.method == "GET":
try: try:
stage_response = self.backend.get_stage( stage_response = self.backend.get_stage(function_id, stage_name)
function_id, stage_name)
except StageNotFoundException as error: except StageNotFoundException as error:
return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type) return (
elif self.method == 'PATCH': error.code,
patch_operations = self._get_param('patchOperations') {},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
stage_response = self.backend.update_stage( stage_response = self.backend.update_stage(
function_id, stage_name, patch_operations) function_id, stage_name, patch_operations
)
elif self.method == "DELETE":
self.backend.delete_stage(function_id, stage_name)
return 202, {}, "{}"
return 200, {}, json.dumps(stage_response) return 200, {}, json.dumps(stage_response)
def integrations(self, request, full_url, headers): def integrations(self, request, full_url, headers):
@ -160,19 +197,40 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4] resource_id = url_path_parts[4]
method_type = url_path_parts[6] method_type = url_path_parts[6]
if self.method == 'GET': try:
integration_response = self.backend.get_integration( if self.method == "GET":
function_id, resource_id, method_type) integration_response = self.backend.get_integration(
elif self.method == 'PUT': function_id, resource_id, method_type
integration_type = self._get_param('type') )
uri = self._get_param('uri') elif self.method == "PUT":
request_templates = self._get_param('requestTemplates') integration_type = self._get_param("type")
integration_response = self.backend.create_integration( uri = self._get_param("uri")
function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates) integration_http_method = self._get_param("httpMethod")
elif self.method == 'DELETE': creds = self._get_param("credentials")
integration_response = self.backend.delete_integration( request_templates = self._get_param("requestTemplates")
function_id, resource_id, method_type) integration_response = self.backend.create_integration(
return 200, {}, json.dumps(integration_response) function_id,
resource_id,
method_type,
integration_type,
uri,
credentials=creds,
integration_method=integration_http_method,
request_templates=request_templates,
)
elif self.method == "DELETE":
integration_response = self.backend.delete_integration(
function_id, resource_id, method_type
)
return 200, {}, json.dumps(integration_response)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
except CrossAccountNotAllowed as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#AccessDeniedException", e.message
)
def integration_responses(self, request, full_url, headers): def integration_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -182,36 +240,52 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
status_code = url_path_parts[9] status_code = url_path_parts[9]
if self.method == 'GET': try:
integration_response = self.backend.get_integration_response( if self.method == "GET":
function_id, resource_id, method_type, status_code integration_response = self.backend.get_integration_response(
function_id, resource_id, method_type, status_code
)
elif self.method == "PUT":
selection_pattern = self._get_param("selectionPattern")
response_templates = self._get_param("responseTemplates")
integration_response = self.backend.create_integration_response(
function_id,
resource_id,
method_type,
status_code,
selection_pattern,
response_templates,
)
elif self.method == "DELETE":
integration_response = self.backend.delete_integration_response(
function_id, resource_id, method_type, status_code
)
return 200, {}, json.dumps(integration_response)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
) )
elif self.method == 'PUT':
selection_pattern = self._get_param("selectionPattern")
integration_response = self.backend.create_integration_response(
function_id, resource_id, method_type, status_code, selection_pattern
)
elif self.method == 'DELETE':
integration_response = self.backend.delete_integration_response(
function_id, resource_id, method_type, status_code
)
return 200, {}, json.dumps(integration_response)
def deployments(self, request, full_url, headers): def deployments(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET': try:
deployments = self.backend.get_deployments(function_id) if self.method == "GET":
return 200, {}, json.dumps({"item": deployments}) deployments = self.backend.get_deployments(function_id)
elif self.method == 'POST': return 200, {}, json.dumps({"item": deployments})
name = self._get_param("stageName") elif self.method == "POST":
description = self._get_param_with_default_value("description", "") name = self._get_param("stageName")
stage_variables = self._get_param_with_default_value( description = self._get_param_with_default_value("description", "")
'variables', {}) stage_variables = self._get_param_with_default_value("variables", {})
deployment = self.backend.create_deployment( deployment = self.backend.create_deployment(
function_id, name, description, stage_variables) function_id, name, description, stage_variables
return 200, {}, json.dumps(deployment) )
return 200, {}, json.dumps(deployment)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
def individual_deployment(self, request, full_url, headers): def individual_deployment(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -219,20 +293,28 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2] function_id = url_path_parts[2]
deployment_id = url_path_parts[4] deployment_id = url_path_parts[4]
if self.method == 'GET': if self.method == "GET":
deployment = self.backend.get_deployment( deployment = self.backend.get_deployment(function_id, deployment_id)
function_id, deployment_id) elif self.method == "DELETE":
elif self.method == 'DELETE': deployment = self.backend.delete_deployment(function_id, deployment_id)
deployment = self.backend.delete_deployment(
function_id, deployment_id)
return 200, {}, json.dumps(deployment) return 200, {}, json.dumps(deployment)
def apikeys(self, request, full_url, headers): def apikeys(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == 'POST': if self.method == "POST":
apikey_response = self.backend.create_apikey(json.loads(self.body)) try:
elif self.method == 'GET': apikey_response = self.backend.create_apikey(json.loads(self.body))
except ApiKeyAlreadyExists as error:
return (
error.code,
self.headers,
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "GET":
apikeys_response = self.backend.get_apikeys() apikeys_response = self.backend.get_apikeys()
return 200, {}, json.dumps({"item": apikeys_response}) return 200, {}, json.dumps({"item": apikeys_response})
return 200, {}, json.dumps(apikey_response) return 200, {}, json.dumps(apikey_response)
@ -243,18 +325,21 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
apikey = url_path_parts[2] apikey = url_path_parts[2]
if self.method == 'GET': if self.method == "GET":
apikey_response = self.backend.get_apikey(apikey) apikey_response = self.backend.get_apikey(apikey)
elif self.method == 'DELETE': elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
apikey_response = self.backend.update_apikey(apikey, patch_operations)
elif self.method == "DELETE":
apikey_response = self.backend.delete_apikey(apikey) apikey_response = self.backend.delete_apikey(apikey)
return 200, {}, json.dumps(apikey_response) return 200, {}, json.dumps(apikey_response)
def usage_plans(self, request, full_url, headers): def usage_plans(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == 'POST': if self.method == "POST":
usage_plan_response = self.backend.create_usage_plan(json.loads(self.body)) usage_plan_response = self.backend.create_usage_plan(json.loads(self.body))
elif self.method == 'GET': elif self.method == "GET":
api_key_id = self.querystring.get("keyId", [None])[0] api_key_id = self.querystring.get("keyId", [None])[0]
usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id) usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id)
return 200, {}, json.dumps({"item": usage_plans_response}) return 200, {}, json.dumps({"item": usage_plans_response})
@ -266,9 +351,9 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
usage_plan = url_path_parts[2] usage_plan = url_path_parts[2]
if self.method == 'GET': if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan(usage_plan) usage_plan_response = self.backend.get_usage_plan(usage_plan)
elif self.method == 'DELETE': elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan(usage_plan) usage_plan_response = self.backend.delete_usage_plan(usage_plan)
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)
@ -278,13 +363,21 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
usage_plan_id = url_path_parts[2] usage_plan_id = url_path_parts[2]
if self.method == 'POST': if self.method == "POST":
try: try:
usage_plan_response = self.backend.create_usage_plan_key(usage_plan_id, json.loads(self.body)) usage_plan_response = self.backend.create_usage_plan_key(
usage_plan_id, json.loads(self.body)
)
except ApiKeyNotFoundException as error: except ApiKeyNotFoundException as error:
return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type) return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == 'GET': elif self.method == "GET":
usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id) usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id)
return 200, {}, json.dumps({"item": usage_plans_response}) return 200, {}, json.dumps({"item": usage_plans_response})
@ -297,8 +390,10 @@ class APIGatewayResponse(BaseResponse):
usage_plan_id = url_path_parts[2] usage_plan_id = url_path_parts[2]
key_id = url_path_parts[4] key_id = url_path_parts[4]
if self.method == 'GET': if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id) usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id)
elif self.method == 'DELETE': elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan_key(usage_plan_id, key_id) usage_plan_response = self.backend.delete_usage_plan_key(
usage_plan_id, key_id
)
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)

View File

@ -1,27 +1,25 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import APIGatewayResponse from .responses import APIGatewayResponse
url_bases = [ url_bases = ["https?://apigateway.(.+).amazonaws.com"]
"https?://apigateway.(.+).amazonaws.com"
]
url_paths = { url_paths = {
'{0}/restapis$': APIGatewayResponse().restapis, "{0}/restapis$": APIGatewayResponse().restapis,
'{0}/restapis/(?P<function_id>[^/]+)/?$': APIGatewayResponse().restapis_individual, "{0}/restapis/(?P<function_id>[^/]+)/?$": APIGatewayResponse().restapis_individual,
'{0}/restapis/(?P<function_id>[^/]+)/resources$': APIGatewayResponse().resources, "{0}/restapis/(?P<function_id>[^/]+)/resources$": APIGatewayResponse().resources,
'{0}/restapis/(?P<function_id>[^/]+)/stages$': APIGatewayResponse().restapis_stages, "{0}/restapis/(?P<function_id>[^/]+)/stages$": APIGatewayResponse().restapis_stages,
'{0}/restapis/(?P<function_id>[^/]+)/stages/(?P<stage_name>[^/]+)/?$': APIGatewayResponse().stages, "{0}/restapis/(?P<function_id>[^/]+)/stages/(?P<stage_name>[^/]+)/?$": APIGatewayResponse().stages,
'{0}/restapis/(?P<function_id>[^/]+)/deployments$': APIGatewayResponse().deployments, "{0}/restapis/(?P<function_id>[^/]+)/deployments$": APIGatewayResponse().deployments,
'{0}/restapis/(?P<function_id>[^/]+)/deployments/(?P<deployment_id>[^/]+)/?$': APIGatewayResponse().individual_deployment, "{0}/restapis/(?P<function_id>[^/]+)/deployments/(?P<deployment_id>[^/]+)/?$": APIGatewayResponse().individual_deployment,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/?$': APIGatewayResponse().resource_individual, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/?$": APIGatewayResponse().resource_individual,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/?$': APIGatewayResponse().resource_methods, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/?$": APIGatewayResponse().resource_methods,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$': APIGatewayResponse().resource_method_responses, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$": APIGatewayResponse().resource_method_responses,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/?$': APIGatewayResponse().integrations, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/?$": APIGatewayResponse().integrations,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$': APIGatewayResponse().integration_responses, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$": APIGatewayResponse().integration_responses,
'{0}/apikeys$': APIGatewayResponse().apikeys, "{0}/apikeys$": APIGatewayResponse().apikeys,
'{0}/apikeys/(?P<apikey>[^/]+)': APIGatewayResponse().apikey_individual, "{0}/apikeys/(?P<apikey>[^/]+)": APIGatewayResponse().apikey_individual,
'{0}/usageplans$': APIGatewayResponse().usage_plans, "{0}/usageplans$": APIGatewayResponse().usage_plans,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$': APIGatewayResponse().usage_plan_individual, "{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$": APIGatewayResponse().usage_plan_individual,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$': APIGatewayResponse().usage_plan_keys, "{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$": APIGatewayResponse().usage_plan_keys,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$': APIGatewayResponse().usage_plan_key_individual, "{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual,
} }

View File

@ -1,9 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import six import six
import random import random
import string
def create_id(): def create_id():
size = 10 size = 10
chars = list(range(10)) + ['A-Z'] chars = list(range(10)) + list(string.ascii_lowercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size)) return "".join(six.text_type(random.choice(chars)) for x in range(size))

7
moto/athena/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from __future__ import unicode_literals
from .models import athena_backends
from ..core.models import base_decorator, deprecated_base_decorator
athena_backend = athena_backends["us-east-1"]
mock_athena = base_decorator(athena_backends)
mock_athena_deprecated = deprecated_base_decorator(athena_backends)

19
moto/athena/exceptions.py Normal file
View File

@ -0,0 +1,19 @@
from __future__ import unicode_literals
import json
from werkzeug.exceptions import BadRequest
class AthenaClientError(BadRequest):
def __init__(self, code, message):
super(AthenaClientError, self).__init__()
self.description = json.dumps(
{
"Error": {
"Code": code,
"Message": message,
"Type": "InvalidRequestException",
},
"RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1",
}
)

86
moto/athena/models.py Normal file
View File

@ -0,0 +1,86 @@
from __future__ import unicode_literals
import time
from boto3 import Session
from moto.core import BaseBackend, BaseModel
from moto.core import ACCOUNT_ID
class TaggableResourceMixin(object):
# This mixing was copied from Redshift when initially implementing
# Athena. TBD if it's worth the overhead.
def __init__(self, region_name, resource_name, tags):
self.region = region_name
self.resource_name = resource_name
self.tags = tags or []
@property
def arn(self):
return "arn:aws:athena:{region}:{account_id}:{resource_name}".format(
region=self.region, account_id=ACCOUNT_ID, resource_name=self.resource_name
)
def create_tags(self, tags):
new_keys = [tag_set["Key"] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
self.tags.extend(tags)
return self.tags
def delete_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
return self.tags
class WorkGroup(TaggableResourceMixin, BaseModel):
resource_type = "workgroup"
state = "ENABLED"
def __init__(self, athena_backend, name, configuration, description, tags):
self.region_name = athena_backend.region_name
super(WorkGroup, self).__init__(
self.region_name, "workgroup/{}".format(name), tags
)
self.athena_backend = athena_backend
self.name = name
self.description = description
self.configuration = configuration
class AthenaBackend(BaseBackend):
region_name = None
def __init__(self, region_name=None):
if region_name is not None:
self.region_name = region_name
self.work_groups = {}
def create_work_group(self, name, configuration, description, tags):
if name in self.work_groups:
return None
work_group = WorkGroup(self, name, configuration, description, tags)
self.work_groups[name] = work_group
return work_group
def list_work_groups(self):
return [
{
"Name": wg.name,
"State": wg.state,
"Description": wg.description,
"CreationTime": time.time(),
}
for wg in self.work_groups.values()
]
athena_backends = {}
for region in Session().get_available_regions("athena"):
athena_backends[region] = AthenaBackend(region)
for region in Session().get_available_regions("athena", partition_name="aws-us-gov"):
athena_backends[region] = AthenaBackend(region)
for region in Session().get_available_regions("athena", partition_name="aws-cn"):
athena_backends[region] = AthenaBackend(region)

41
moto/athena/responses.py Normal file
View File

@ -0,0 +1,41 @@
import json
from moto.core.responses import BaseResponse
from .models import athena_backends
class AthenaResponse(BaseResponse):
@property
def athena_backend(self):
return athena_backends[self.region]
def create_work_group(self):
name = self._get_param("Name")
description = self._get_param("Description")
configuration = self._get_param("Configuration")
tags = self._get_param("Tags")
work_group = self.athena_backend.create_work_group(
name, configuration, description, tags
)
if not work_group:
return (
json.dumps(
{
"__type": "InvalidRequestException",
"Message": "WorkGroup already exists",
}
),
dict(status=400),
)
return json.dumps(
{
"CreateWorkGroupResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a"
}
}
}
)
def list_work_groups(self):
return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()})

6
moto/athena/urls.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .responses import AthenaResponse
url_bases = ["https?://athena.(.+).amazonaws.com"]
url_paths = {"{0}/$": AthenaResponse.dispatch}

1
moto/athena/utils.py Normal file
View File

@ -0,0 +1 @@
from __future__ import unicode_literals

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import autoscaling_backends from .models import autoscaling_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
autoscaling_backend = autoscaling_backends['us-east-1'] autoscaling_backend = autoscaling_backends["us-east-1"]
mock_autoscaling = base_decorator(autoscaling_backends) mock_autoscaling = base_decorator(autoscaling_backends)
mock_autoscaling_deprecated = deprecated_base_decorator(autoscaling_backends) mock_autoscaling_deprecated = deprecated_base_decorator(autoscaling_backends)

View File

@ -12,4 +12,12 @@ class ResourceContentionError(RESTError):
def __init__(self): def __init__(self):
super(ResourceContentionError, self).__init__( super(ResourceContentionError, self).__init__(
"ResourceContentionError", "ResourceContentionError",
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).") "You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).",
)
class InvalidInstanceError(AutoscalingClientError):
def __init__(self, instance_id):
super(InvalidInstanceError, self).__init__(
"ValidationError", "Instance [{0}] is invalid.".format(instance_id)
)

View File

@ -1,5 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import random
from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping
from moto.ec2.exceptions import InvalidInstanceIdError
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.ec2 import ec2_backends from moto.ec2 import ec2_backends
@ -7,7 +12,9 @@ from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends from moto.elbv2 import elbv2_backends
from moto.elb.exceptions import LoadBalancerNotFoundError from moto.elb.exceptions import LoadBalancerNotFoundError
from .exceptions import ( from .exceptions import (
AutoscalingClientError, ResourceContentionError, AutoscalingClientError,
ResourceContentionError,
InvalidInstanceError,
) )
# http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown # http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
@ -17,8 +24,13 @@ ASG_NAME_TAG = "aws:autoscaling:groupName"
class InstanceState(object): class InstanceState(object):
def __init__(self, instance, lifecycle_state="InService", def __init__(
health_status="Healthy", protected_from_scale_in=False): self,
instance,
lifecycle_state="InService",
health_status="Healthy",
protected_from_scale_in=False,
):
self.instance = instance self.instance = instance
self.lifecycle_state = lifecycle_state self.lifecycle_state = lifecycle_state
self.health_status = health_status self.health_status = health_status
@ -26,8 +38,16 @@ class InstanceState(object):
class FakeScalingPolicy(BaseModel): class FakeScalingPolicy(BaseModel):
def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment, def __init__(
cooldown, autoscaling_backend): self,
name,
policy_type,
adjustment_type,
as_name,
scaling_adjustment,
cooldown,
autoscaling_backend,
):
self.name = name self.name = name
self.policy_type = policy_type self.policy_type = policy_type
self.adjustment_type = adjustment_type self.adjustment_type = adjustment_type
@ -40,21 +60,38 @@ class FakeScalingPolicy(BaseModel):
self.autoscaling_backend = autoscaling_backend self.autoscaling_backend = autoscaling_backend
def execute(self): def execute(self):
if self.adjustment_type == 'ExactCapacity': if self.adjustment_type == "ExactCapacity":
self.autoscaling_backend.set_desired_capacity( self.autoscaling_backend.set_desired_capacity(
self.as_name, self.scaling_adjustment) self.as_name, self.scaling_adjustment
elif self.adjustment_type == 'ChangeInCapacity': )
elif self.adjustment_type == "ChangeInCapacity":
self.autoscaling_backend.change_capacity( self.autoscaling_backend.change_capacity(
self.as_name, self.scaling_adjustment) self.as_name, self.scaling_adjustment
elif self.adjustment_type == 'PercentChangeInCapacity': )
elif self.adjustment_type == "PercentChangeInCapacity":
self.autoscaling_backend.change_capacity_percent( self.autoscaling_backend.change_capacity_percent(
self.as_name, self.scaling_adjustment) self.as_name, self.scaling_adjustment
)
class FakeLaunchConfiguration(BaseModel): class FakeLaunchConfiguration(BaseModel):
def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data, def __init__(
instance_type, instance_monitoring, instance_profile_name, self,
spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict): name,
image_id,
key_name,
ramdisk_id,
kernel_id,
security_groups,
user_data,
instance_type,
instance_monitoring,
instance_profile_name,
spot_price,
ebs_optimized,
associate_public_ip_address,
block_device_mapping_dict,
):
self.name = name self.name = name
self.image_id = image_id self.image_id = image_id
self.key_name = key_name self.key_name = key_name
@ -71,8 +108,30 @@ class FakeLaunchConfiguration(BaseModel):
self.block_device_mapping_dict = block_device_mapping_dict self.block_device_mapping_dict = block_device_mapping_dict
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_instance(cls, name, instance, backend):
properties = cloudformation_json['Properties'] config = backend.create_launch_configuration(
name=name,
image_id=instance.image_id,
kernel_id="",
ramdisk_id="",
key_name=instance.key_name,
security_groups=instance.security_groups,
user_data=instance.user_data,
instance_type=instance.instance_type,
instance_monitoring=False,
instance_profile_name=None,
spot_price=None,
ebs_optimized=instance.ebs_optimized,
associate_public_ip_address=instance.associate_public_ip,
block_device_mappings=instance.block_device_mapping,
)
return config
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
instance_profile_name = properties.get("IamInstanceProfile") instance_profile_name = properties.get("IamInstanceProfile")
@ -90,20 +149,26 @@ class FakeLaunchConfiguration(BaseModel):
instance_profile_name=instance_profile_name, instance_profile_name=instance_profile_name,
spot_price=properties.get("SpotPrice"), spot_price=properties.get("SpotPrice"),
ebs_optimized=properties.get("EbsOptimized"), ebs_optimized=properties.get("EbsOptimized"),
associate_public_ip_address=properties.get( associate_public_ip_address=properties.get("AssociatePublicIpAddress"),
"AssociatePublicIpAddress"), block_device_mappings=properties.get("BlockDeviceMapping.member"),
block_device_mappings=properties.get("BlockDeviceMapping.member")
) )
return config return config
@classmethod @classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
cls.delete_from_cloudformation_json( cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name) original_resource.name, cloudformation_json, region_name
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) )
return cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
@classmethod @classmethod
def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
backend = autoscaling_backends[region_name] backend = autoscaling_backends[region_name]
try: try:
backend.delete_launch_configuration(resource_name) backend.delete_launch_configuration(resource_name)
@ -128,69 +193,116 @@ class FakeLaunchConfiguration(BaseModel):
@property @property
def instance_monitoring_enabled(self): def instance_monitoring_enabled(self):
if self.instance_monitoring: if self.instance_monitoring:
return 'true' return "true"
return 'false' return "false"
def _parse_block_device_mappings(self): def _parse_block_device_mappings(self):
block_device_map = BlockDeviceMapping() block_device_map = BlockDeviceMapping()
for mapping in self.block_device_mapping_dict: for mapping in self.block_device_mapping_dict:
block_type = BlockDeviceType() block_type = BlockDeviceType()
mount_point = mapping.get('device_name') mount_point = mapping.get("device_name")
if 'ephemeral' in mapping.get('virtual_name', ''): if "ephemeral" in mapping.get("virtual_name", ""):
block_type.ephemeral_name = mapping.get('virtual_name') block_type.ephemeral_name = mapping.get("virtual_name")
else: else:
block_type.volume_type = mapping.get('ebs._volume_type') block_type.volume_type = mapping.get("ebs._volume_type")
block_type.snapshot_id = mapping.get('ebs._snapshot_id') block_type.snapshot_id = mapping.get("ebs._snapshot_id")
block_type.delete_on_termination = mapping.get( block_type.delete_on_termination = mapping.get(
'ebs._delete_on_termination') "ebs._delete_on_termination"
block_type.size = mapping.get('ebs._volume_size') )
block_type.iops = mapping.get('ebs._iops') block_type.size = mapping.get("ebs._volume_size")
block_type.iops = mapping.get("ebs._iops")
block_device_map[mount_point] = block_type block_device_map[mount_point] = block_type
return block_device_map return block_device_map
class FakeAutoScalingGroup(BaseModel): class FakeAutoScalingGroup(BaseModel):
def __init__(self, name, availability_zones, desired_capacity, max_size, def __init__(
min_size, launch_config_name, vpc_zone_identifier, self,
default_cooldown, health_check_period, health_check_type, name,
load_balancers, target_group_arns, placement_group, termination_policies, availability_zones,
autoscaling_backend, tags, desired_capacity,
new_instances_protected_from_scale_in=False): max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
load_balancers,
target_group_arns,
placement_group,
termination_policies,
autoscaling_backend,
tags,
new_instances_protected_from_scale_in=False,
):
self.autoscaling_backend = autoscaling_backend self.autoscaling_backend = autoscaling_backend
self.name = name self.name = name
if not availability_zones and not vpc_zone_identifier: self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier)
raise AutoscalingClientError(
"ValidationError",
"At least one Availability Zone or VPC Subnet is required."
)
self.availability_zones = availability_zones
self.vpc_zone_identifier = vpc_zone_identifier
self.max_size = max_size self.max_size = max_size
self.min_size = min_size self.min_size = min_size
self.launch_config = self.autoscaling_backend.launch_configurations[ self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name] launch_config_name
]
self.launch_config_name = launch_config_name self.launch_config_name = launch_config_name
self.default_cooldown = default_cooldown if default_cooldown else DEFAULT_COOLDOWN self.default_cooldown = (
default_cooldown if default_cooldown else DEFAULT_COOLDOWN
)
self.health_check_period = health_check_period self.health_check_period = health_check_period
self.health_check_type = health_check_type if health_check_type else "EC2" self.health_check_type = health_check_type if health_check_type else "EC2"
self.load_balancers = load_balancers self.load_balancers = load_balancers
self.target_group_arns = target_group_arns self.target_group_arns = target_group_arns
self.placement_group = placement_group self.placement_group = placement_group
self.termination_policies = termination_policies self.termination_policies = termination_policies
self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in self.new_instances_protected_from_scale_in = (
new_instances_protected_from_scale_in
)
self.suspended_processes = [] self.suspended_processes = []
self.instance_states = [] self.instance_states = []
self.tags = tags if tags else [] self.tags = tags if tags else []
self.set_desired_capacity(desired_capacity) self.set_desired_capacity(desired_capacity)
def _set_azs_and_vpcs(self, availability_zones, vpc_zone_identifier, update=False):
# for updates, if only AZs are provided, they must not clash with
# the AZs of existing VPCs
if update and availability_zones and not vpc_zone_identifier:
vpc_zone_identifier = self.vpc_zone_identifier
if vpc_zone_identifier:
# extract azs for vpcs
subnet_ids = vpc_zone_identifier.split(",")
subnets = self.autoscaling_backend.ec2_backend.get_all_subnets(
subnet_ids=subnet_ids
)
vpc_zones = [subnet.availability_zone for subnet in subnets]
if availability_zones and set(availability_zones) != set(vpc_zones):
raise AutoscalingClientError(
"ValidationError",
"The availability zones of the specified subnets and the Auto Scaling group do not match",
)
availability_zones = vpc_zones
elif not availability_zones:
if not update:
raise AutoscalingClientError(
"ValidationError",
"At least one Availability Zone or VPC Subnet is required.",
)
return
self.availability_zones = availability_zones
self.vpc_zone_identifier = vpc_zone_identifier
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(
properties = cloudformation_json['Properties'] cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
launch_config_name = properties.get("LaunchConfigurationName") launch_config_name = properties.get("LaunchConfigurationName")
load_balancer_names = properties.get("LoadBalancerNames", []) load_balancer_names = properties.get("LoadBalancerNames", [])
@ -205,7 +317,8 @@ class FakeAutoScalingGroup(BaseModel):
min_size=properties.get("MinSize"), min_size=properties.get("MinSize"),
launch_config_name=launch_config_name, launch_config_name=launch_config_name,
vpc_zone_identifier=( vpc_zone_identifier=(
','.join(properties.get("VPCZoneIdentifier", [])) or None), ",".join(properties.get("VPCZoneIdentifier", [])) or None
),
default_cooldown=properties.get("Cooldown"), default_cooldown=properties.get("Cooldown"),
health_check_period=properties.get("HealthCheckGracePeriod"), health_check_period=properties.get("HealthCheckGracePeriod"),
health_check_type=properties.get("HealthCheckType"), health_check_type=properties.get("HealthCheckType"),
@ -215,18 +328,26 @@ class FakeAutoScalingGroup(BaseModel):
termination_policies=properties.get("TerminationPolicies", []), termination_policies=properties.get("TerminationPolicies", []),
tags=properties.get("Tags", []), tags=properties.get("Tags", []),
new_instances_protected_from_scale_in=properties.get( new_instances_protected_from_scale_in=properties.get(
"NewInstancesProtectedFromScaleIn", False) "NewInstancesProtectedFromScaleIn", False
),
) )
return group return group
@classmethod @classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
cls.delete_from_cloudformation_json( cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name) original_resource.name, cloudformation_json, region_name
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) )
return cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
@classmethod @classmethod
def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
backend = autoscaling_backends[region_name] backend = autoscaling_backends[region_name]
try: try:
backend.delete_auto_scaling_group(resource_name) backend.delete_auto_scaling_group(resource_name)
@ -241,30 +362,47 @@ class FakeAutoScalingGroup(BaseModel):
def physical_resource_id(self): def physical_resource_id(self):
return self.name return self.name
def update(self, availability_zones, desired_capacity, max_size, min_size, def update(
launch_config_name, vpc_zone_identifier, default_cooldown, self,
health_check_period, health_check_type, availability_zones,
placement_group, termination_policies, desired_capacity,
new_instances_protected_from_scale_in=None): max_size,
if availability_zones: min_size,
self.availability_zones = availability_zones launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=None,
):
self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier, update=True)
if max_size is not None: if max_size is not None:
self.max_size = max_size self.max_size = max_size
if min_size is not None: if min_size is not None:
self.min_size = min_size self.min_size = min_size
if desired_capacity is None:
if min_size is not None and min_size > len(self.instance_states):
desired_capacity = min_size
if max_size is not None and max_size < len(self.instance_states):
desired_capacity = max_size
if launch_config_name: if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[ self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name] launch_config_name
]
self.launch_config_name = launch_config_name self.launch_config_name = launch_config_name
if vpc_zone_identifier is not None:
self.vpc_zone_identifier = vpc_zone_identifier
if health_check_period is not None: if health_check_period is not None:
self.health_check_period = health_check_period self.health_check_period = health_check_period
if health_check_type is not None: if health_check_type is not None:
self.health_check_type = health_check_type self.health_check_type = health_check_type
if new_instances_protected_from_scale_in is not None: if new_instances_protected_from_scale_in is not None:
self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in self.new_instances_protected_from_scale_in = (
new_instances_protected_from_scale_in
)
if desired_capacity is not None: if desired_capacity is not None:
self.set_desired_capacity(desired_capacity) self.set_desired_capacity(desired_capacity)
@ -290,25 +428,30 @@ class FakeAutoScalingGroup(BaseModel):
# Need to remove some instances # Need to remove some instances
count_to_remove = curr_instance_count - self.desired_capacity count_to_remove = curr_instance_count - self.desired_capacity
instances_to_remove = [ # only remove unprotected instances_to_remove = [ # only remove unprotected
state for state in self.instance_states state
for state in self.instance_states
if not state.protected_from_scale_in if not state.protected_from_scale_in
][:count_to_remove] ][:count_to_remove]
if instances_to_remove: # just in case not instances to remove if instances_to_remove: # just in case not instances to remove
instance_ids_to_remove = [ instance_ids_to_remove = [
instance.instance.id for instance in instances_to_remove] instance.instance.id for instance in instances_to_remove
]
self.autoscaling_backend.ec2_backend.terminate_instances( self.autoscaling_backend.ec2_backend.terminate_instances(
instance_ids_to_remove) instance_ids_to_remove
self.instance_states = list(set(self.instance_states) - set(instances_to_remove)) )
self.instance_states = list(
set(self.instance_states) - set(instances_to_remove)
)
def get_propagated_tags(self): def get_propagated_tags(self):
propagated_tags = {} propagated_tags = {}
for tag in self.tags: for tag in self.tags:
# boto uses 'propagate_at_launch # boto uses 'propagate_at_launch
# boto3 and cloudformation use PropagateAtLaunch # boto3 and cloudformation use PropagateAtLaunch
if 'propagate_at_launch' in tag and tag['propagate_at_launch'] == 'true': if "propagate_at_launch" in tag and tag["propagate_at_launch"] == "true":
propagated_tags[tag['key']] = tag['value'] propagated_tags[tag["key"]] = tag["value"]
if 'PropagateAtLaunch' in tag and tag['PropagateAtLaunch']: if "PropagateAtLaunch" in tag and tag["PropagateAtLaunch"]:
propagated_tags[tag['Key']] = tag['Value'] propagated_tags[tag["Key"]] = tag["Value"]
return propagated_tags return propagated_tags
def replace_autoscaling_group_instances(self, count_needed, propagated_tags): def replace_autoscaling_group_instances(self, count_needed, propagated_tags):
@ -319,14 +462,17 @@ class FakeAutoScalingGroup(BaseModel):
self.launch_config.user_data, self.launch_config.user_data,
self.launch_config.security_groups, self.launch_config.security_groups,
instance_type=self.launch_config.instance_type, instance_type=self.launch_config.instance_type,
tags={'instance': propagated_tags} tags={"instance": propagated_tags},
placement=random.choice(self.availability_zones),
) )
for instance in reservation.instances: for instance in reservation.instances:
instance.autoscaling_group = self instance.autoscaling_group = self
self.instance_states.append(InstanceState( self.instance_states.append(
instance, InstanceState(
protected_from_scale_in=self.new_instances_protected_from_scale_in, instance,
)) protected_from_scale_in=self.new_instances_protected_from_scale_in,
)
)
def append_target_groups(self, target_group_arns): def append_target_groups(self, target_group_arns):
append = [x for x in target_group_arns if x not in self.target_group_arns] append = [x for x in target_group_arns if x not in self.target_group_arns]
@ -349,10 +495,23 @@ class AutoScalingBackend(BaseBackend):
self.__dict__ = {} self.__dict__ = {}
self.__init__(ec2_backend, elb_backend, elbv2_backend) self.__init__(ec2_backend, elb_backend, elbv2_backend)
def create_launch_configuration(self, name, image_id, key_name, kernel_id, ramdisk_id, def create_launch_configuration(
security_groups, user_data, instance_type, self,
instance_monitoring, instance_profile_name, name,
spot_price, ebs_optimized, associate_public_ip_address, block_device_mappings): image_id,
key_name,
kernel_id,
ramdisk_id,
security_groups,
user_data,
instance_type,
instance_monitoring,
instance_profile_name,
spot_price,
ebs_optimized,
associate_public_ip_address,
block_device_mappings,
):
launch_configuration = FakeLaunchConfiguration( launch_configuration = FakeLaunchConfiguration(
name=name, name=name,
image_id=image_id, image_id=image_id,
@ -375,22 +534,37 @@ class AutoScalingBackend(BaseBackend):
def describe_launch_configurations(self, names): def describe_launch_configurations(self, names):
configurations = self.launch_configurations.values() configurations = self.launch_configurations.values()
if names: if names:
return [configuration for configuration in configurations if configuration.name in names] return [
configuration
for configuration in configurations
if configuration.name in names
]
else: else:
return list(configurations) return list(configurations)
def delete_launch_configuration(self, launch_configuration_name): def delete_launch_configuration(self, launch_configuration_name):
self.launch_configurations.pop(launch_configuration_name, None) self.launch_configurations.pop(launch_configuration_name, None)
def create_auto_scaling_group(self, name, availability_zones, def create_auto_scaling_group(
desired_capacity, max_size, min_size, self,
launch_config_name, vpc_zone_identifier, name,
default_cooldown, health_check_period, availability_zones,
health_check_type, load_balancers, desired_capacity,
target_group_arns, placement_group, max_size,
termination_policies, tags, min_size,
new_instances_protected_from_scale_in=False): launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
load_balancers,
target_group_arns,
placement_group,
termination_policies,
tags,
new_instances_protected_from_scale_in=False,
instance_id=None,
):
def make_int(value): def make_int(value):
return int(value) if value is not None else value return int(value) if value is not None else value
@ -402,6 +576,15 @@ class AutoScalingBackend(BaseBackend):
health_check_period = 300 health_check_period = 300
else: else:
health_check_period = make_int(health_check_period) health_check_period = make_int(health_check_period)
if launch_config_name is None and instance_id is not None:
try:
instance = self.ec2_backend.get_instance(instance_id)
launch_config_name = name
FakeLaunchConfiguration.create_from_instance(
launch_config_name, instance, self
)
except InvalidInstanceIdError:
raise InvalidInstanceError(instance_id)
group = FakeAutoScalingGroup( group = FakeAutoScalingGroup(
name=name, name=name,
@ -428,19 +611,37 @@ class AutoScalingBackend(BaseBackend):
self.update_attached_target_groups(group.name) self.update_attached_target_groups(group.name)
return group return group
def update_auto_scaling_group(self, name, availability_zones, def update_auto_scaling_group(
desired_capacity, max_size, min_size, self,
launch_config_name, vpc_zone_identifier, name,
default_cooldown, health_check_period, availability_zones,
health_check_type, placement_group, desired_capacity,
termination_policies, max_size,
new_instances_protected_from_scale_in=None): min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=None,
):
group = self.autoscaling_groups[name] group = self.autoscaling_groups[name]
group.update(availability_zones, desired_capacity, max_size, group.update(
min_size, launch_config_name, vpc_zone_identifier, availability_zones,
default_cooldown, health_check_period, health_check_type, desired_capacity,
placement_group, termination_policies, max_size,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in) min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in,
)
return group return group
def describe_auto_scaling_groups(self, names): def describe_auto_scaling_groups(self, names):
@ -476,32 +677,48 @@ class AutoScalingBackend(BaseBackend):
for x in instance_ids for x in instance_ids
] ]
for instance in new_instances: for instance in new_instances:
self.ec2_backend.create_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) self.ec2_backend.create_tags(
[instance.instance.id], {ASG_NAME_TAG: group.name}
)
group.instance_states.extend(new_instances) group.instance_states.extend(new_instances)
self.update_attached_elbs(group.name) self.update_attached_elbs(group.name)
def set_instance_health(self, instance_id, health_status, should_respect_grace_period): def set_instance_health(
self, instance_id, health_status, should_respect_grace_period
):
instance = self.ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
instance_state = next(instance_state for group in self.autoscaling_groups.values() instance_state = next(
for instance_state in group.instance_states if instance_state.instance.id == instance.id) instance_state
for group in self.autoscaling_groups.values()
for instance_state in group.instance_states
if instance_state.instance.id == instance.id
)
instance_state.health_status = health_status instance_state.health_status = health_status
def detach_instances(self, group_name, instance_ids, should_decrement): def detach_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
original_size = len(group.instance_states) original_size = len(group.instance_states)
detached_instances = [x for x in group.instance_states if x.instance.id in instance_ids] detached_instances = [
x for x in group.instance_states if x.instance.id in instance_ids
]
for instance in detached_instances: for instance in detached_instances:
self.ec2_backend.delete_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) self.ec2_backend.delete_tags(
[instance.instance.id], {ASG_NAME_TAG: group.name}
)
new_instance_state = [x for x in group.instance_states if x.instance.id not in instance_ids] new_instance_state = [
x for x in group.instance_states if x.instance.id not in instance_ids
]
group.instance_states = new_instance_state group.instance_states = new_instance_state
if should_decrement: if should_decrement:
group.desired_capacity = original_size - len(instance_ids) group.desired_capacity = original_size - len(instance_ids)
else: else:
count_needed = len(instance_ids) count_needed = len(instance_ids)
group.replace_autoscaling_group_instances(count_needed, group.get_propagated_tags()) group.replace_autoscaling_group_instances(
count_needed, group.get_propagated_tags()
)
self.update_attached_elbs(group_name) self.update_attached_elbs(group_name)
return detached_instances return detached_instances
@ -532,19 +749,32 @@ class AutoScalingBackend(BaseBackend):
desired_capacity = int(desired_capacity) desired_capacity = int(desired_capacity)
self.set_desired_capacity(group_name, desired_capacity) self.set_desired_capacity(group_name, desired_capacity)
def create_autoscaling_policy(self, name, policy_type, adjustment_type, as_name, def create_autoscaling_policy(
scaling_adjustment, cooldown): self, name, policy_type, adjustment_type, as_name, scaling_adjustment, cooldown
policy = FakeScalingPolicy(name, policy_type, adjustment_type, as_name, ):
scaling_adjustment, cooldown, self) policy = FakeScalingPolicy(
name,
policy_type,
adjustment_type,
as_name,
scaling_adjustment,
cooldown,
self,
)
self.policies[name] = policy self.policies[name] = policy
return policy return policy
def describe_policies(self, autoscaling_group_name=None, policy_names=None, policy_types=None): def describe_policies(
return [policy for policy in self.policies.values() self, autoscaling_group_name=None, policy_names=None, policy_types=None
if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) and ):
(not policy_names or policy.name in policy_names) and return [
(not policy_types or policy.policy_type in policy_types)] policy
for policy in self.policies.values()
if (not autoscaling_group_name or policy.as_name == autoscaling_group_name)
and (not policy_names or policy.name in policy_names)
and (not policy_types or policy.policy_type in policy_types)
]
def delete_policy(self, group_name): def delete_policy(self, group_name):
self.policies.pop(group_name, None) self.policies.pop(group_name, None)
@ -555,16 +785,14 @@ class AutoScalingBackend(BaseBackend):
def update_attached_elbs(self, group_name): def update_attached_elbs(self, group_name):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group_instance_ids = set( group_instance_ids = set(state.instance.id for state in group.instance_states)
state.instance.id for state in group.instance_states)
# skip this if group.load_balancers is empty # skip this if group.load_balancers is empty
# otherwise elb_backend.describe_load_balancers returns all available load balancers # otherwise elb_backend.describe_load_balancers returns all available load balancers
if not group.load_balancers: if not group.load_balancers:
return return
try: try:
elbs = self.elb_backend.describe_load_balancers( elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
names=group.load_balancers)
except LoadBalancerNotFoundError: except LoadBalancerNotFoundError:
# ELBs can be deleted before their autoscaling group # ELBs can be deleted before their autoscaling group
return return
@ -572,14 +800,15 @@ class AutoScalingBackend(BaseBackend):
for elb in elbs: for elb in elbs:
elb_instace_ids = set(elb.instance_ids) elb_instace_ids = set(elb.instance_ids)
self.elb_backend.register_instances( self.elb_backend.register_instances(
elb.name, group_instance_ids - elb_instace_ids) elb.name, group_instance_ids - elb_instace_ids
)
self.elb_backend.deregister_instances( self.elb_backend.deregister_instances(
elb.name, elb_instace_ids - group_instance_ids) elb.name, elb_instace_ids - group_instance_ids
)
def update_attached_target_groups(self, group_name): def update_attached_target_groups(self, group_name):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group_instance_ids = set( group_instance_ids = set(state.instance.id for state in group.instance_states)
state.instance.id for state in group.instance_states)
# no action necessary if target_group_arns is empty # no action necessary if target_group_arns is empty
if not group.target_group_arns: if not group.target_group_arns:
@ -588,10 +817,13 @@ class AutoScalingBackend(BaseBackend):
target_groups = self.elbv2_backend.describe_target_groups( target_groups = self.elbv2_backend.describe_target_groups(
target_group_arns=group.target_group_arns, target_group_arns=group.target_group_arns,
load_balancer_arn=None, load_balancer_arn=None,
names=None) names=None,
)
for target_group in target_groups: for target_group in target_groups:
asg_targets = [{'id': x, 'port': target_group.port} for x in group_instance_ids] asg_targets = [
{"id": x, "port": target_group.port} for x in group_instance_ids
]
self.elbv2_backend.register_targets(target_group.arn, (asg_targets)) self.elbv2_backend.register_targets(target_group.arn, (asg_targets))
def create_or_update_tags(self, tags): def create_or_update_tags(self, tags):
@ -609,7 +841,7 @@ class AutoScalingBackend(BaseBackend):
new_tags.append(old_tag) new_tags.append(old_tag)
# if key was never in old_tag's add it (create tag) # if key was never in old_tag's add it (create tag)
if not any(new_tag['key'] == tag['key'] for new_tag in new_tags): if not any(new_tag["key"] == tag["key"] for new_tag in new_tags):
new_tags.append(tag) new_tags.append(tag)
group.tags = new_tags group.tags = new_tags
@ -617,7 +849,8 @@ class AutoScalingBackend(BaseBackend):
def attach_load_balancers(self, group_name, load_balancer_names): def attach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group.load_balancers.extend( group.load_balancers.extend(
[x for x in load_balancer_names if x not in group.load_balancers]) [x for x in load_balancer_names if x not in group.load_balancers]
)
self.update_attached_elbs(group_name) self.update_attached_elbs(group_name)
def describe_load_balancers(self, group_name): def describe_load_balancers(self, group_name):
@ -625,13 +858,13 @@ class AutoScalingBackend(BaseBackend):
def detach_load_balancers(self, group_name, load_balancer_names): def detach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group_instance_ids = set( group_instance_ids = set(state.instance.id for state in group.instance_states)
state.instance.id for state in group.instance_states)
elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers) elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
for elb in elbs: for elb in elbs:
self.elb_backend.deregister_instances( self.elb_backend.deregister_instances(elb.name, group_instance_ids)
elb.name, group_instance_ids) group.load_balancers = [
group.load_balancers = [x for x in group.load_balancers if x not in load_balancer_names] x for x in group.load_balancers if x not in load_balancer_names
]
def attach_load_balancer_target_groups(self, group_name, target_group_arns): def attach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
@ -643,24 +876,51 @@ class AutoScalingBackend(BaseBackend):
def detach_load_balancer_target_groups(self, group_name, target_group_arns): def detach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group.target_group_arns = [x for x in group.target_group_arns if x not in target_group_arns] group.target_group_arns = [
x for x in group.target_group_arns if x not in target_group_arns
]
for target_group in target_group_arns: for target_group in target_group_arns:
asg_targets = [{'id': x.instance.id} for x in group.instance_states] asg_targets = [{"id": x.instance.id} for x in group.instance_states]
self.elbv2_backend.deregister_targets(target_group, (asg_targets)) self.elbv2_backend.deregister_targets(target_group, (asg_targets))
def suspend_processes(self, group_name, scaling_processes): def suspend_processes(self, group_name, scaling_processes):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group.suspended_processes = scaling_processes or [] group.suspended_processes = scaling_processes or []
def set_instance_protection(self, group_name, instance_ids, protected_from_scale_in): def set_instance_protection(
self, group_name, instance_ids, protected_from_scale_in
):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
protected_instances = [ protected_instances = [
x for x in group.instance_states if x.instance.id in instance_ids] x for x in group.instance_states if x.instance.id in instance_ids
]
for instance in protected_instances: for instance in protected_instances:
instance.protected_from_scale_in = protected_from_scale_in instance.protected_from_scale_in = protected_from_scale_in
def notify_terminate_instances(self, instance_ids):
for (
autoscaling_group_name,
autoscaling_group,
) in self.autoscaling_groups.items():
original_instance_count = len(autoscaling_group.instance_states)
autoscaling_group.instance_states = list(
filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states,
)
)
difference = original_instance_count - len(
autoscaling_group.instance_states
)
if difference > 0:
autoscaling_group.replace_autoscaling_group_instances(
difference, autoscaling_group.get_propagated_tags()
)
self.update_attached_elbs(autoscaling_group_name)
autoscaling_backends = {} autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items(): for region, ec2_backend in ec2_backends.items():
autoscaling_backends[region] = AutoScalingBackend( autoscaling_backends[region] = AutoScalingBackend(
ec2_backend, elb_backends[region], elbv2_backends[region]) ec2_backend, elb_backends[region], elbv2_backends[region]
)

View File

@ -6,87 +6,88 @@ from .models import autoscaling_backends
class AutoScalingResponse(BaseResponse): class AutoScalingResponse(BaseResponse):
@property @property
def autoscaling_backend(self): def autoscaling_backend(self):
return autoscaling_backends[self.region] return autoscaling_backends[self.region]
def create_launch_configuration(self): def create_launch_configuration(self):
instance_monitoring_string = self._get_param( instance_monitoring_string = self._get_param("InstanceMonitoring.Enabled")
'InstanceMonitoring.Enabled') if instance_monitoring_string == "true":
if instance_monitoring_string == 'true':
instance_monitoring = True instance_monitoring = True
else: else:
instance_monitoring = False instance_monitoring = False
self.autoscaling_backend.create_launch_configuration( self.autoscaling_backend.create_launch_configuration(
name=self._get_param('LaunchConfigurationName'), name=self._get_param("LaunchConfigurationName"),
image_id=self._get_param('ImageId'), image_id=self._get_param("ImageId"),
key_name=self._get_param('KeyName'), key_name=self._get_param("KeyName"),
ramdisk_id=self._get_param('RamdiskId'), ramdisk_id=self._get_param("RamdiskId"),
kernel_id=self._get_param('KernelId'), kernel_id=self._get_param("KernelId"),
security_groups=self._get_multi_param('SecurityGroups.member'), security_groups=self._get_multi_param("SecurityGroups.member"),
user_data=self._get_param('UserData'), user_data=self._get_param("UserData"),
instance_type=self._get_param('InstanceType'), instance_type=self._get_param("InstanceType"),
instance_monitoring=instance_monitoring, instance_monitoring=instance_monitoring,
instance_profile_name=self._get_param('IamInstanceProfile'), instance_profile_name=self._get_param("IamInstanceProfile"),
spot_price=self._get_param('SpotPrice'), spot_price=self._get_param("SpotPrice"),
ebs_optimized=self._get_param('EbsOptimized'), ebs_optimized=self._get_param("EbsOptimized"),
associate_public_ip_address=self._get_param( associate_public_ip_address=self._get_param("AssociatePublicIpAddress"),
"AssociatePublicIpAddress"), block_device_mappings=self._get_list_prefix("BlockDeviceMappings.member"),
block_device_mappings=self._get_list_prefix(
'BlockDeviceMappings.member')
) )
template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE) template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render() return template.render()
def describe_launch_configurations(self): def describe_launch_configurations(self):
names = self._get_multi_param('LaunchConfigurationNames.member') names = self._get_multi_param("LaunchConfigurationNames.member")
all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(names) all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(
marker = self._get_param('NextToken') names
)
marker = self._get_param("NextToken")
all_names = [lc.name for lc in all_launch_configurations] all_names = [lc.name for lc in all_launch_configurations]
if marker: if marker:
start = all_names.index(marker) + 1 start = all_names.index(marker) + 1
else: else:
start = 0 start = 0
max_records = self._get_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier max_records = self._get_int_param(
launch_configurations_resp = all_launch_configurations[start:start + max_records] "MaxRecords", 50
) # the default is 100, but using 50 to make testing easier
launch_configurations_resp = all_launch_configurations[
start : start + max_records
]
next_token = None next_token = None
if len(all_launch_configurations) > start + max_records: if len(all_launch_configurations) > start + max_records:
next_token = launch_configurations_resp[-1].name next_token = launch_configurations_resp[-1].name
template = self.response_template( template = self.response_template(DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE)
DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE) return template.render(
return template.render(launch_configurations=launch_configurations_resp, next_token=next_token) launch_configurations=launch_configurations_resp, next_token=next_token
)
def delete_launch_configuration(self): def delete_launch_configuration(self):
launch_configurations_name = self.querystring.get( launch_configurations_name = self.querystring.get("LaunchConfigurationName")[0]
'LaunchConfigurationName')[0] self.autoscaling_backend.delete_launch_configuration(launch_configurations_name)
self.autoscaling_backend.delete_launch_configuration(
launch_configurations_name)
template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE) template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render() return template.render()
def create_auto_scaling_group(self): def create_auto_scaling_group(self):
self.autoscaling_backend.create_auto_scaling_group( self.autoscaling_backend.create_auto_scaling_group(
name=self._get_param('AutoScalingGroupName'), name=self._get_param("AutoScalingGroupName"),
availability_zones=self._get_multi_param( availability_zones=self._get_multi_param("AvailabilityZones.member"),
'AvailabilityZones.member'), desired_capacity=self._get_int_param("DesiredCapacity"),
desired_capacity=self._get_int_param('DesiredCapacity'), max_size=self._get_int_param("MaxSize"),
max_size=self._get_int_param('MaxSize'), min_size=self._get_int_param("MinSize"),
min_size=self._get_int_param('MinSize'), instance_id=self._get_param("InstanceId"),
launch_config_name=self._get_param('LaunchConfigurationName'), launch_config_name=self._get_param("LaunchConfigurationName"),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
default_cooldown=self._get_int_param('DefaultCooldown'), default_cooldown=self._get_int_param("DefaultCooldown"),
health_check_period=self._get_int_param('HealthCheckGracePeriod'), health_check_period=self._get_int_param("HealthCheckGracePeriod"),
health_check_type=self._get_param('HealthCheckType'), health_check_type=self._get_param("HealthCheckType"),
load_balancers=self._get_multi_param('LoadBalancerNames.member'), load_balancers=self._get_multi_param("LoadBalancerNames.member"),
target_group_arns=self._get_multi_param('TargetGroupARNs.member'), target_group_arns=self._get_multi_param("TargetGroupARNs.member"),
placement_group=self._get_param('PlacementGroup'), placement_group=self._get_param("PlacementGroup"),
termination_policies=self._get_multi_param( termination_policies=self._get_multi_param("TerminationPolicies.member"),
'TerminationPolicies.member'), tags=self._get_list_prefix("Tags.member"),
tags=self._get_list_prefix('Tags.member'),
new_instances_protected_from_scale_in=self._get_bool_param( new_instances_protected_from_scale_in=self._get_bool_param(
'NewInstancesProtectedFromScaleIn', False) "NewInstancesProtectedFromScaleIn", False
),
) )
template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render() return template.render()
@ -94,68 +95,73 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def attach_instances(self): def attach_instances(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param('InstanceIds.member') instance_ids = self._get_multi_param("InstanceIds.member")
self.autoscaling_backend.attach_instances( self.autoscaling_backend.attach_instances(group_name, instance_ids)
group_name, instance_ids)
template = self.response_template(ATTACH_INSTANCES_TEMPLATE) template = self.response_template(ATTACH_INSTANCES_TEMPLATE)
return template.render() return template.render()
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def set_instance_health(self): def set_instance_health(self):
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
health_status = self._get_param("HealthStatus") health_status = self._get_param("HealthStatus")
if health_status not in ['Healthy', 'Unhealthy']: if health_status not in ["Healthy", "Unhealthy"]:
raise ValueError('Valid instance health states are: [Healthy, Unhealthy]') raise ValueError("Valid instance health states are: [Healthy, Unhealthy]")
should_respect_grace_period = self._get_param("ShouldRespectGracePeriod") should_respect_grace_period = self._get_param("ShouldRespectGracePeriod")
self.autoscaling_backend.set_instance_health(instance_id, health_status, should_respect_grace_period) self.autoscaling_backend.set_instance_health(
instance_id, health_status, should_respect_grace_period
)
template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE) template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE)
return template.render() return template.render()
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def detach_instances(self): def detach_instances(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param('InstanceIds.member') instance_ids = self._get_multi_param("InstanceIds.member")
should_decrement_string = self._get_param('ShouldDecrementDesiredCapacity') should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
if should_decrement_string == 'true': if should_decrement_string == "true":
should_decrement = True should_decrement = True
else: else:
should_decrement = False should_decrement = False
detached_instances = self.autoscaling_backend.detach_instances( detached_instances = self.autoscaling_backend.detach_instances(
group_name, instance_ids, should_decrement) group_name, instance_ids, should_decrement
)
template = self.response_template(DETACH_INSTANCES_TEMPLATE) template = self.response_template(DETACH_INSTANCES_TEMPLATE)
return template.render(detached_instances=detached_instances) return template.render(detached_instances=detached_instances)
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def attach_load_balancer_target_groups(self): def attach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self._get_multi_param('TargetGroupARNs.member') target_group_arns = self._get_multi_param("TargetGroupARNs.member")
self.autoscaling_backend.attach_load_balancer_target_groups( self.autoscaling_backend.attach_load_balancer_target_groups(
group_name, target_group_arns) group_name, target_group_arns
)
template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE) template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render() return template.render()
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def describe_load_balancer_target_groups(self): def describe_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups( target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups(
group_name) group_name
)
template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS) template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS)
return template.render(target_group_arns=target_group_arns) return template.render(target_group_arns=target_group_arns)
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def detach_load_balancer_target_groups(self): def detach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self._get_multi_param('TargetGroupARNs.member') target_group_arns = self._get_multi_param("TargetGroupARNs.member")
self.autoscaling_backend.detach_load_balancer_target_groups( self.autoscaling_backend.detach_load_balancer_target_groups(
group_name, target_group_arns) group_name, target_group_arns
)
template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE) template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render() return template.render()
@ -171,7 +177,7 @@ class AutoScalingResponse(BaseResponse):
max_records = self._get_int_param("MaxRecords", 50) max_records = self._get_int_param("MaxRecords", 50)
if max_records > 100: if max_records > 100:
raise ValueError raise ValueError
groups = all_groups[start:start + max_records] groups = all_groups[start : start + max_records]
next_token = None next_token = None
if max_records and len(all_groups) > start + max_records: if max_records and len(all_groups) > start + max_records:
next_token = groups[-1].name next_token = groups[-1].name
@ -180,42 +186,40 @@ class AutoScalingResponse(BaseResponse):
def update_auto_scaling_group(self): def update_auto_scaling_group(self):
self.autoscaling_backend.update_auto_scaling_group( self.autoscaling_backend.update_auto_scaling_group(
name=self._get_param('AutoScalingGroupName'), name=self._get_param("AutoScalingGroupName"),
availability_zones=self._get_multi_param( availability_zones=self._get_multi_param("AvailabilityZones.member"),
'AvailabilityZones.member'), desired_capacity=self._get_int_param("DesiredCapacity"),
desired_capacity=self._get_int_param('DesiredCapacity'), max_size=self._get_int_param("MaxSize"),
max_size=self._get_int_param('MaxSize'), min_size=self._get_int_param("MinSize"),
min_size=self._get_int_param('MinSize'), launch_config_name=self._get_param("LaunchConfigurationName"),
launch_config_name=self._get_param('LaunchConfigurationName'), vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), default_cooldown=self._get_int_param("DefaultCooldown"),
default_cooldown=self._get_int_param('DefaultCooldown'), health_check_period=self._get_int_param("HealthCheckGracePeriod"),
health_check_period=self._get_int_param('HealthCheckGracePeriod'), health_check_type=self._get_param("HealthCheckType"),
health_check_type=self._get_param('HealthCheckType'), placement_group=self._get_param("PlacementGroup"),
placement_group=self._get_param('PlacementGroup'), termination_policies=self._get_multi_param("TerminationPolicies.member"),
termination_policies=self._get_multi_param(
'TerminationPolicies.member'),
new_instances_protected_from_scale_in=self._get_bool_param( new_instances_protected_from_scale_in=self._get_bool_param(
'NewInstancesProtectedFromScaleIn', None) "NewInstancesProtectedFromScaleIn", None
),
) )
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render() return template.render()
def delete_auto_scaling_group(self): def delete_auto_scaling_group(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
self.autoscaling_backend.delete_auto_scaling_group(group_name) self.autoscaling_backend.delete_auto_scaling_group(group_name)
template = self.response_template(DELETE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(DELETE_AUTOSCALING_GROUP_TEMPLATE)
return template.render() return template.render()
def set_desired_capacity(self): def set_desired_capacity(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
desired_capacity = self._get_int_param('DesiredCapacity') desired_capacity = self._get_int_param("DesiredCapacity")
self.autoscaling_backend.set_desired_capacity( self.autoscaling_backend.set_desired_capacity(group_name, desired_capacity)
group_name, desired_capacity)
template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE) template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE)
return template.render() return template.render()
def create_or_update_tags(self): def create_or_update_tags(self):
tags = self._get_list_prefix('Tags.member') tags = self._get_list_prefix("Tags.member")
self.autoscaling_backend.create_or_update_tags(tags) self.autoscaling_backend.create_or_update_tags(tags)
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
@ -223,38 +227,38 @@ class AutoScalingResponse(BaseResponse):
def describe_auto_scaling_instances(self): def describe_auto_scaling_instances(self):
instance_states = self.autoscaling_backend.describe_auto_scaling_instances() instance_states = self.autoscaling_backend.describe_auto_scaling_instances()
template = self.response_template( template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
return template.render(instance_states=instance_states) return template.render(instance_states=instance_states)
def put_scaling_policy(self): def put_scaling_policy(self):
policy = self.autoscaling_backend.create_autoscaling_policy( policy = self.autoscaling_backend.create_autoscaling_policy(
name=self._get_param('PolicyName'), name=self._get_param("PolicyName"),
policy_type=self._get_param('PolicyType'), policy_type=self._get_param("PolicyType"),
adjustment_type=self._get_param('AdjustmentType'), adjustment_type=self._get_param("AdjustmentType"),
as_name=self._get_param('AutoScalingGroupName'), as_name=self._get_param("AutoScalingGroupName"),
scaling_adjustment=self._get_int_param('ScalingAdjustment'), scaling_adjustment=self._get_int_param("ScalingAdjustment"),
cooldown=self._get_int_param('Cooldown'), cooldown=self._get_int_param("Cooldown"),
) )
template = self.response_template(CREATE_SCALING_POLICY_TEMPLATE) template = self.response_template(CREATE_SCALING_POLICY_TEMPLATE)
return template.render(policy=policy) return template.render(policy=policy)
def describe_policies(self): def describe_policies(self):
policies = self.autoscaling_backend.describe_policies( policies = self.autoscaling_backend.describe_policies(
autoscaling_group_name=self._get_param('AutoScalingGroupName'), autoscaling_group_name=self._get_param("AutoScalingGroupName"),
policy_names=self._get_multi_param('PolicyNames.member'), policy_names=self._get_multi_param("PolicyNames.member"),
policy_types=self._get_multi_param('PolicyTypes.member')) policy_types=self._get_multi_param("PolicyTypes.member"),
)
template = self.response_template(DESCRIBE_SCALING_POLICIES_TEMPLATE) template = self.response_template(DESCRIBE_SCALING_POLICIES_TEMPLATE)
return template.render(policies=policies) return template.render(policies=policies)
def delete_policy(self): def delete_policy(self):
group_name = self._get_param('PolicyName') group_name = self._get_param("PolicyName")
self.autoscaling_backend.delete_policy(group_name) self.autoscaling_backend.delete_policy(group_name)
template = self.response_template(DELETE_POLICY_TEMPLATE) template = self.response_template(DELETE_POLICY_TEMPLATE)
return template.render() return template.render()
def execute_policy(self): def execute_policy(self):
group_name = self._get_param('PolicyName') group_name = self._get_param("PolicyName")
self.autoscaling_backend.execute_policy(group_name) self.autoscaling_backend.execute_policy(group_name)
template = self.response_template(EXECUTE_POLICY_TEMPLATE) template = self.response_template(EXECUTE_POLICY_TEMPLATE)
return template.render() return template.render()
@ -262,17 +266,16 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def attach_load_balancers(self): def attach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
load_balancer_names = self._get_multi_param("LoadBalancerNames.member") load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.attach_load_balancers( self.autoscaling_backend.attach_load_balancers(group_name, load_balancer_names)
group_name, load_balancer_names)
template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE) template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE)
return template.render() return template.render()
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def describe_load_balancers(self): def describe_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
load_balancers = self.autoscaling_backend.describe_load_balancers(group_name) load_balancers = self.autoscaling_backend.describe_load_balancers(group_name)
template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE) template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE)
return template.render(load_balancers=load_balancers) return template.render(load_balancers=load_balancers)
@ -280,26 +283,28 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def detach_load_balancers(self): def detach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
load_balancer_names = self._get_multi_param("LoadBalancerNames.member") load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.detach_load_balancers( self.autoscaling_backend.detach_load_balancers(group_name, load_balancer_names)
group_name, load_balancer_names)
template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE) template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE)
return template.render() return template.render()
def suspend_processes(self): def suspend_processes(self):
autoscaling_group_name = self._get_param('AutoScalingGroupName') autoscaling_group_name = self._get_param("AutoScalingGroupName")
scaling_processes = self._get_multi_param('ScalingProcesses.member') scaling_processes = self._get_multi_param("ScalingProcesses.member")
self.autoscaling_backend.suspend_processes(autoscaling_group_name, scaling_processes) self.autoscaling_backend.suspend_processes(
autoscaling_group_name, scaling_processes
)
template = self.response_template(SUSPEND_PROCESSES_TEMPLATE) template = self.response_template(SUSPEND_PROCESSES_TEMPLATE)
return template.render() return template.render()
def set_instance_protection(self): def set_instance_protection(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param('InstanceIds.member') instance_ids = self._get_multi_param("InstanceIds.member")
protected_from_scale_in = self._get_bool_param('ProtectedFromScaleIn') protected_from_scale_in = self._get_bool_param("ProtectedFromScaleIn")
self.autoscaling_backend.set_instance_protection( self.autoscaling_backend.set_instance_protection(
group_name, instance_ids, protected_from_scale_in) group_name, instance_ids, protected_from_scale_in
)
template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE) template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE)
return template.render() return template.render()
@ -499,7 +504,7 @@ DESCRIBE_AUTOSCALING_GROUPS_TEMPLATE = """<DescribeAutoScalingGroupsResponse xml
{% for instance_state in group.instance_states %} {% for instance_state in group.instance_states %}
<member> <member>
<HealthStatus>{{ instance_state.health_status }}</HealthStatus> <HealthStatus>{{ instance_state.health_status }}</HealthStatus>
<AvailabilityZone>us-east-1e</AvailabilityZone> <AvailabilityZone>{{ instance_state.instance.placement }}</AvailabilityZone>
<InstanceId>{{ instance_state.instance.id }}</InstanceId> <InstanceId>{{ instance_state.instance.id }}</InstanceId>
<LaunchConfigurationName>{{ group.launch_config_name }}</LaunchConfigurationName> <LaunchConfigurationName>{{ group.launch_config_name }}</LaunchConfigurationName>
<LifecycleState>{{ instance_state.lifecycle_state }}</LifecycleState> <LifecycleState>{{ instance_state.lifecycle_state }}</LifecycleState>
@ -585,7 +590,7 @@ DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE = """<DescribeAutoScalingInstancesRespon
<member> <member>
<HealthStatus>{{ instance_state.health_status }}</HealthStatus> <HealthStatus>{{ instance_state.health_status }}</HealthStatus>
<AutoScalingGroupName>{{ instance_state.instance.autoscaling_group.name }}</AutoScalingGroupName> <AutoScalingGroupName>{{ instance_state.instance.autoscaling_group.name }}</AutoScalingGroupName>
<AvailabilityZone>us-east-1e</AvailabilityZone> <AvailabilityZone>{{ instance_state.instance.placement }}</AvailabilityZone>
<InstanceId>{{ instance_state.instance.id }}</InstanceId> <InstanceId>{{ instance_state.instance.id }}</InstanceId>
<LaunchConfigurationName>{{ instance_state.instance.autoscaling_group.launch_config_name }}</LaunchConfigurationName> <LaunchConfigurationName>{{ instance_state.instance.autoscaling_group.launch_config_name }}</LaunchConfigurationName>
<LifecycleState>{{ instance_state.lifecycle_state }}</LifecycleState> <LifecycleState>{{ instance_state.lifecycle_state }}</LifecycleState>

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import AutoScalingResponse from .responses import AutoScalingResponse
url_bases = [ url_bases = ["https?://autoscaling.(.+).amazonaws.com"]
"https?://autoscaling.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": AutoScalingResponse.dispatch}
'{0}/$': AutoScalingResponse.dispatch,
}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import lambda_backends from .models import lambda_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
lambda_backend = lambda_backends['us-east-1'] lambda_backend = lambda_backends["us-east-1"]
mock_lambda = base_decorator(lambda_backends) mock_lambda = base_decorator(lambda_backends)
mock_lambda_deprecated = deprecated_base_decorator(lambda_backends) mock_lambda_deprecated = deprecated_base_decorator(lambda_backends)

View File

@ -0,0 +1,41 @@
from botocore.client import ClientError
from moto.core.exceptions import JsonRESTError
class LambdaClientError(ClientError):
def __init__(self, error, message):
error_response = {"Error": {"Code": error, "Message": message}}
super(LambdaClientError, self).__init__(error_response, None)
class CrossAccountNotAllowed(LambdaClientError):
def __init__(self):
super(CrossAccountNotAllowed, self).__init__(
"AccessDeniedException", "Cross-account pass role is not allowed."
)
class InvalidParameterValueException(LambdaClientError):
def __init__(self, message):
super(InvalidParameterValueException, self).__init__(
"InvalidParameterValueException", message
)
class InvalidRoleFormat(LambdaClientError):
pattern = r"arn:(aws[a-zA-Z-]*)?:iam::(\d{12}):role/?[a-zA-Z_0-9+=,.@\-_/]+"
def __init__(self, role):
message = "1 validation error detected: Value '{0}' at 'role' failed to satisfy constraint: Member must satisfy regular expression pattern: {1}".format(
role, InvalidRoleFormat.pattern
)
super(InvalidRoleFormat, self).__init__("ValidationException", message)
class PreconditionFailedException(JsonRESTError):
code = 412
def __init__(self, message):
super(PreconditionFailedException, self).__init__(
"PreconditionFailedException", message
)

File diff suppressed because it is too large Load Diff

134
moto/awslambda/policy.py Normal file
View File

@ -0,0 +1,134 @@
from __future__ import unicode_literals
import json
import uuid
from six import string_types
from moto.awslambda.exceptions import PreconditionFailedException
class Policy:
def __init__(self, parent):
self.revision = str(uuid.uuid4())
self.statements = []
self.parent = parent
def wire_format(self):
p = self.get_policy()
p["Policy"] = json.dumps(p["Policy"])
return json.dumps(p)
def get_policy(self):
return {
"Policy": {
"Version": "2012-10-17",
"Id": "default",
"Statement": self.statements,
},
"RevisionId": self.revision,
}
# adds the raw JSON statement to the policy
def add_statement(self, raw):
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
self.statements.append(policy.statements[0])
self.revision = str(uuid.uuid4())
# removes the statement that matches 'sid' from the policy
def del_statement(self, sid, revision=""):
if len(revision) > 0 and self.revision != revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
for statement in self.statements:
if "Sid" in statement and statement["Sid"] == sid:
self.statements.remove(statement)
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj):
# import pydevd
# pydevd.settrace("localhost", port=5678)
policy = Policy(self.parent)
policy.revision = obj.get("RevisionId", "")
# set some default values if these keys are not set
self.ensure_set(obj, "Effect", "Allow")
self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST")
self.ensure_set(obj, "StatementId", str(uuid.uuid4()))
# transform field names and values
self.transform_property(obj, "StatementId", "Sid", self.nop_formatter)
self.transform_property(obj, "Principal", "Principal", self.principal_formatter)
self.transform_property(
obj, "SourceArn", "SourceArn", self.source_arn_formatter
)
self.transform_property(
obj, "SourceAccount", "SourceAccount", self.source_account_formatter
)
# remove RevisionId and EventSourceToken if they are set
self.remove_if_set(obj, ["RevisionId", "EventSourceToken"])
# merge conditional statements into a single map under the Condition key
self.condition_merge(obj)
# append resulting statement to policy.statements
policy.statements.append(obj)
return policy
def nop_formatter(self, obj):
return obj
def ensure_set(self, obj, key, value):
if key not in obj:
obj[key] = value
def principal_formatter(self, obj):
if isinstance(obj, string_types):
if obj.endswith(".amazonaws.com"):
return {"Service": obj}
if obj.endswith(":root"):
return {"AWS": obj}
return obj
def source_account_formatter(self, obj):
return {"StringEquals": {"AWS:SourceAccount": obj}}
def source_arn_formatter(self, obj):
return {"ArnLike": {"AWS:SourceArn": obj}}
def transform_property(self, obj, old_name, new_name, formatter):
if old_name in obj:
obj[new_name] = formatter(obj[old_name])
if new_name != old_name:
del obj[old_name]
def remove_if_set(self, obj, keys):
for key in keys:
if key in obj:
del obj[key]
def condition_merge(self, obj):
if "SourceArn" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceArn"])
del obj["SourceArn"]
if "SourceAccount" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceAccount"])
del obj["SourceAccount"]

View File

@ -32,32 +32,57 @@ class LambdaResponse(BaseResponse):
def root(self, request, full_url, headers): def root(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
return self._list_functions(request, full_url, headers) return self._list_functions(request, full_url, headers)
elif request.method == 'POST': elif request.method == "POST":
return self._create_function(request, full_url, headers) return self._create_function(request, full_url, headers)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def event_source_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":
querystring = self.querystring
event_source_arn = querystring.get("EventSourceArn", [None])[0]
function_name = querystring.get("FunctionName", [None])[0]
return self._list_event_source_mappings(event_source_arn, function_name)
elif request.method == "POST":
return self._create_event_source_mapping(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def event_source_mapping(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
path = request.path if hasattr(request, "path") else path_url(request.url)
uuid = path.split("/")[-1]
if request.method == "GET":
return self._get_event_source_mapping(uuid)
elif request.method == "PUT":
return self._update_event_source_mapping(uuid)
elif request.method == "DELETE":
return self._delete_event_source_mapping(uuid)
else:
raise ValueError("Cannot handle request")
def function(self, request, full_url, headers): def function(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
return self._get_function(request, full_url, headers) return self._get_function(request, full_url, headers)
elif request.method == 'DELETE': elif request.method == "DELETE":
return self._delete_function(request, full_url, headers) return self._delete_function(request, full_url, headers)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def versions(self, request, full_url, headers): def versions(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
# This is ListVersionByFunction # This is ListVersionByFunction
path = request.path if hasattr(request, 'path') else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split('/')[-2] function_name = path.split("/")[-2]
return self._list_versions_by_function(function_name) return self._list_versions_by_function(function_name)
elif request.method == 'POST': elif request.method == "POST":
return self._publish_function(request, full_url, headers) return self._publish_function(request, full_url, headers)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
@ -66,7 +91,7 @@ class LambdaResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def invoke(self, request, full_url, headers): def invoke(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'POST': if request.method == "POST":
return self._invoke(request, full_url) return self._invoke(request, full_url)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
@ -75,57 +100,89 @@ class LambdaResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def invoke_async(self, request, full_url, headers): def invoke_async(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'POST': if request.method == "POST":
return self._invoke_async(request, full_url) return self._invoke_async(request, full_url)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def tag(self, request, full_url, headers): def tag(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
return self._list_tags(request, full_url) return self._list_tags(request, full_url)
elif request.method == 'POST': elif request.method == "POST":
return self._tag_resource(request, full_url) return self._tag_resource(request, full_url)
elif request.method == 'DELETE': elif request.method == "DELETE":
return self._untag_resource(request, full_url) return self._untag_resource(request, full_url)
else: else:
raise ValueError("Cannot handle {0} request".format(request.method)) raise ValueError("Cannot handle {0} request".format(request.method))
def policy(self, request, full_url, headers): def policy(self, request, full_url, headers):
if request.method == 'GET': self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._get_policy(request, full_url, headers) return self._get_policy(request, full_url, headers)
if request.method == 'POST': elif request.method == "POST":
return self._add_policy(request, full_url, headers) return self._add_policy(request, full_url, headers)
elif request.method == "DELETE":
return self._del_policy(request, full_url, headers, self.querystring)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self._put_configuration(request)
else:
raise ValueError("Cannot handle request")
def code(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self._put_code()
else:
raise ValueError("Cannot handle request")
def _add_policy(self, request, full_url, headers): def _add_policy(self, request, full_url, headers):
path = request.path if hasattr(request, 'path') else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split('/')[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
policy = request.body.decode('utf8') statement = self.body
self.lambda_backend.add_policy(function_name, policy) self.lambda_backend.add_policy_statement(function_name, statement)
return 200, {}, json.dumps(dict(Statement=policy)) return 200, {}, json.dumps({"Statement": statement})
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _get_policy(self, request, full_url, headers): def _get_policy(self, request, full_url, headers):
path = request.path if hasattr(request, 'path') else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split('/')[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
lambda_function = self.lambda_backend.get_function(function_name) out = self.lambda_backend.get_policy_wire_format(function_name)
return 200, {}, json.dumps(dict(Policy="{\"Statement\":[" + lambda_function.policy + "]}")) return 200, {}, out
else:
return 404, {}, "{}"
def _del_policy(self, request, full_url, headers, querystring):
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-3]
statement_id = path.split("/")[-1].split("?")[0]
revision = querystring.get("RevisionId", "")
if self.lambda_backend.get_function(function_name):
self.lambda_backend.del_policy_statement(
function_name, statement_id, revision
)
return 204, {}, "{}"
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _invoke(self, request, full_url): def _invoke(self, request, full_url):
response_headers = {} response_headers = {}
function_name = self.path.rsplit('/', 2)[-2] function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param('qualifier') qualifier = self._get_param("qualifier")
fn = self.lambda_backend.get_function(function_name, qualifier) response_header, payload = self.lambda_backend.invoke(
if fn: function_name, qualifier, self.body, self.headers, response_headers
payload = fn.invoke(self.body, self.headers, response_headers) )
response_headers['Content-Length'] = str(len(payload)) if payload:
return 202, response_headers, payload return 202, response_headers, payload
else: else:
return 404, response_headers, "{}" return 404, response_headers, "{}"
@ -133,52 +190,79 @@ class LambdaResponse(BaseResponse):
def _invoke_async(self, request, full_url): def _invoke_async(self, request, full_url):
response_headers = {} response_headers = {}
function_name = self.path.rsplit('/', 3)[-3] function_name = self.path.rsplit("/", 3)[-3]
fn = self.lambda_backend.get_function(function_name, None) fn = self.lambda_backend.get_function(function_name, None)
if fn: if fn:
payload = fn.invoke(self.body, self.headers, response_headers) payload = fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(len(payload)) response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload return 202, response_headers, payload
else: else:
return 404, response_headers, "{}" return 404, response_headers, "{}"
def _list_functions(self, request, full_url, headers): def _list_functions(self, request, full_url, headers):
result = { result = {"Functions": []}
'Functions': []
}
for fn in self.lambda_backend.list_functions(): for fn in self.lambda_backend.list_functions():
json_data = fn.get_configuration() json_data = fn.get_configuration()
json_data["Version"] = "$LATEST"
result['Functions'].append(json_data) result["Functions"].append(json_data)
return 200, {}, json.dumps(result) return 200, {}, json.dumps(result)
def _list_versions_by_function(self, function_name): def _list_versions_by_function(self, function_name):
result = { result = {"Versions": []}
'Versions': []
}
functions = self.lambda_backend.list_versions_by_function(function_name) functions = self.lambda_backend.list_versions_by_function(function_name)
if functions: if functions:
for fn in functions: for fn in functions:
json_data = fn.get_configuration() json_data = fn.get_configuration()
result['Versions'].append(json_data) result["Versions"].append(json_data)
return 200, {}, json.dumps(result) return 200, {}, json.dumps(result)
def _create_function(self, request, full_url, headers): def _create_function(self, request, full_url, headers):
try: fn = self.lambda_backend.create_function(self.json_body)
fn = self.lambda_backend.create_function(self.json_body) config = fn.get_configuration()
except ValueError as e: return 201, {}, json.dumps(config)
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}})
def _create_event_source_mapping(self, request, full_url, headers):
fn = self.lambda_backend.create_event_source_mapping(self.json_body)
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _list_event_source_mappings(self, event_source_arn, function_name):
esms = self.lambda_backend.list_event_source_mappings(
event_source_arn, function_name
)
result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]}
return 200, {}, json.dumps(result)
def _get_event_source_mapping(self, uuid):
result = self.lambda_backend.get_event_source_mapping(uuid)
if result:
return 200, {}, json.dumps(result.get_configuration())
else: else:
config = fn.get_configuration() return 404, {}, "{}"
return 201, {}, json.dumps(config)
def _update_event_source_mapping(self, uuid):
result = self.lambda_backend.update_event_source_mapping(uuid, self.json_body)
if result:
return 202, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"
def _delete_event_source_mapping(self, uuid):
esm = self.lambda_backend.delete_event_source_mapping(uuid)
if esm:
json_result = esm.get_configuration()
json_result.update({"State": "Deleting"})
return 202, {}, json.dumps(json_result)
else:
return 404, {}, "{}"
def _publish_function(self, request, full_url, headers): def _publish_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 2)[-2] function_name = self.path.rsplit("/", 2)[-2]
fn = self.lambda_backend.publish_function(function_name) fn = self.lambda_backend.publish_function(function_name)
if fn: if fn:
@ -188,8 +272,8 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}" return 404, {}, "{}"
def _delete_function(self, request, full_url, headers): def _delete_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 1)[-1] function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param('Qualifier', None) qualifier = self._get_param("Qualifier", None)
if self.lambda_backend.delete_function(function_name, qualifier): if self.lambda_backend.delete_function(function_name, qualifier):
return 204, {}, "" return 204, {}, ""
@ -197,14 +281,17 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}" return 404, {}, "{}"
def _get_function(self, request, full_url, headers): def _get_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 1)[-1] function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param('Qualifier', None) qualifier = self._get_param("Qualifier", None)
fn = self.lambda_backend.get_function(function_name, qualifier) fn = self.lambda_backend.get_function(function_name, qualifier)
if fn: if fn:
code = fn.get_code() code = fn.get_code()
if qualifier is None or qualifier == "$LATEST":
code["Configuration"]["Version"] = "$LATEST"
if qualifier == "$LATEST":
code["Configuration"]["FunctionArn"] += ":$LATEST"
return 200, {}, json.dumps(code) return 200, {}, json.dumps(code)
else: else:
return 404, {}, "{}" return 404, {}, "{}"
@ -217,27 +304,51 @@ class LambdaResponse(BaseResponse):
return self.default_region return self.default_region
def _list_tags(self, request, full_url): def _list_tags(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
fn = self.lambda_backend.get_function_by_arn(function_arn) fn = self.lambda_backend.get_function_by_arn(function_arn)
if fn: if fn:
return 200, {}, json.dumps({'Tags': fn.tags}) return 200, {}, json.dumps({"Tags": fn.tags})
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _tag_resource(self, request, full_url): def _tag_resource(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
if self.lambda_backend.tag_resource(function_arn, self.json_body['Tags']): if self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]):
return 200, {}, "{}" return 200, {}, "{}"
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _untag_resource(self, request, full_url): def _untag_resource(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
tag_keys = self.querystring['tagKeys'] tag_keys = self.querystring["tagKeys"]
if self.lambda_backend.untag_resource(function_arn, tag_keys): if self.lambda_backend.untag_resource(function_arn, tag_keys):
return 204, {}, "{}" return 204, {}, "{}"
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _put_configuration(self, request):
function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param("Qualifier", None)
resp = self.lambda_backend.update_function_configuration(
function_name, qualifier, body=self.json_body
)
if resp:
return 200, {}, json.dumps(resp)
else:
return 404, {}, "{}"
def _put_code(self):
function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param("Qualifier", None)
resp = self.lambda_backend.update_function_code(
function_name, qualifier, body=self.json_body
)
if resp:
return 200, {}, json.dumps(resp)
else:
return 404, {}, "{}"

View File

@ -1,18 +1,21 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import LambdaResponse from .responses import LambdaResponse
url_bases = [ url_bases = ["https?://lambda.(.+).amazonaws.com"]
"https?://lambda.(.+).amazonaws.com",
]
response = LambdaResponse() response = LambdaResponse()
url_paths = { url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root, r"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$': response.versions, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke, r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async, r"{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$": response.event_source_mapping,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$': response.policy r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/(?P<statement_id>[\w_-]+)$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,
} }

View File

@ -1,15 +1,20 @@
from collections import namedtuple from collections import namedtuple
ARN = namedtuple('ARN', ['region', 'account', 'function_name', 'version']) ARN = namedtuple("ARN", ["region", "account", "function_name", "version"])
def make_function_arn(region, account, name, version='1'): def make_function_arn(region, account, name):
return 'arn:aws:lambda:{0}:{1}:function:{2}:{3}'.format(region, account, name, version) return "arn:aws:lambda:{0}:{1}:function:{2}".format(region, account, name)
def make_function_ver_arn(region, account, name, version="1"):
arn = make_function_arn(region, account, name)
return "{0}:{1}".format(arn, version)
def split_function_arn(arn): def split_function_arn(arn):
arn = arn.replace('arn:aws:lambda:') arn = arn.replace("arn:aws:lambda:")
region, account, _, name, version = arn.split(':') region, account, _, name, version = arn.split(":")
return ARN(region, account, name, version) return ARN(region, account, name, version)

View File

@ -2,18 +2,25 @@ from __future__ import unicode_literals
from moto.acm import acm_backends from moto.acm import acm_backends
from moto.apigateway import apigateway_backends from moto.apigateway import apigateway_backends
from moto.athena import athena_backends
from moto.autoscaling import autoscaling_backends from moto.autoscaling import autoscaling_backends
from moto.awslambda import lambda_backends from moto.awslambda import lambda_backends
from moto.batch import batch_backends
from moto.cloudformation import cloudformation_backends from moto.cloudformation import cloudformation_backends
from moto.cloudwatch import cloudwatch_backends from moto.cloudwatch import cloudwatch_backends
from moto.codecommit import codecommit_backends
from moto.codepipeline import codepipeline_backends
from moto.cognitoidentity import cognitoidentity_backends from moto.cognitoidentity import cognitoidentity_backends
from moto.cognitoidp import cognitoidp_backends from moto.cognitoidp import cognitoidp_backends
from moto.config import config_backends
from moto.core import moto_api_backends from moto.core import moto_api_backends
from moto.datapipeline import datapipeline_backends from moto.datapipeline import datapipeline_backends
from moto.datasync import datasync_backends
from moto.dynamodb import dynamodb_backends from moto.dynamodb import dynamodb_backends
from moto.dynamodb2 import dynamodb_backends2 from moto.dynamodb2 import dynamodb_backends2
from moto.dynamodbstreams import dynamodbstreams_backends from moto.dynamodbstreams import dynamodbstreams_backends
from moto.ec2 import ec2_backends from moto.ec2 import ec2_backends
from moto.ec2_instance_connect import ec2_instance_connect_backends
from moto.ecr import ecr_backends from moto.ecr import ecr_backends
from moto.ecs import ecs_backends from moto.ecs import ecs_backends
from moto.elb import elb_backends from moto.elb import elb_backends
@ -24,6 +31,8 @@ from moto.glacier import glacier_backends
from moto.glue import glue_backends from moto.glue import glue_backends
from moto.iam import iam_backends from moto.iam import iam_backends
from moto.instance_metadata import instance_metadata_backends from moto.instance_metadata import instance_metadata_backends
from moto.iot import iot_backends
from moto.iotdata import iotdata_backends
from moto.kinesis import kinesis_backends from moto.kinesis import kinesis_backends
from moto.kms import kms_backends from moto.kms import kms_backends
from moto.logs import logs_backends from moto.logs import logs_backends
@ -32,71 +41,76 @@ from moto.organizations import organizations_backends
from moto.polly import polly_backends from moto.polly import polly_backends
from moto.rds2 import rds2_backends from moto.rds2 import rds2_backends
from moto.redshift import redshift_backends from moto.redshift import redshift_backends
from moto.resourcegroups import resourcegroups_backends
from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends
from moto.route53 import route53_backends from moto.route53 import route53_backends
from moto.s3 import s3_backends from moto.s3 import s3_backends
from moto.ses import ses_backends
from moto.secretsmanager import secretsmanager_backends from moto.secretsmanager import secretsmanager_backends
from moto.ses import ses_backends
from moto.sns import sns_backends from moto.sns import sns_backends
from moto.sqs import sqs_backends from moto.sqs import sqs_backends
from moto.ssm import ssm_backends from moto.ssm import ssm_backends
from moto.stepfunctions import stepfunction_backends
from moto.sts import sts_backends from moto.sts import sts_backends
from moto.swf import swf_backends from moto.swf import swf_backends
from moto.xray import xray_backends from moto.xray import xray_backends
from moto.iot import iot_backends
from moto.iotdata import iotdata_backends
from moto.batch import batch_backends
from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends
from moto.config import config_backends
BACKENDS = { BACKENDS = {
'acm': acm_backends, "acm": acm_backends,
'apigateway': apigateway_backends, "apigateway": apigateway_backends,
'autoscaling': autoscaling_backends, "athena": athena_backends,
'batch': batch_backends, "autoscaling": autoscaling_backends,
'cloudformation': cloudformation_backends, "batch": batch_backends,
'cloudwatch': cloudwatch_backends, "cloudformation": cloudformation_backends,
'cognito-identity': cognitoidentity_backends, "cloudwatch": cloudwatch_backends,
'cognito-idp': cognitoidp_backends, "codecommit": codecommit_backends,
'config': config_backends, "codepipeline": codepipeline_backends,
'datapipeline': datapipeline_backends, "cognito-identity": cognitoidentity_backends,
'dynamodb': dynamodb_backends, "cognito-idp": cognitoidp_backends,
'dynamodb2': dynamodb_backends2, "config": config_backends,
'dynamodbstreams': dynamodbstreams_backends, "datapipeline": datapipeline_backends,
'ec2': ec2_backends, "datasync": datasync_backends,
'ecr': ecr_backends, "dynamodb": dynamodb_backends,
'ecs': ecs_backends, "dynamodb2": dynamodb_backends2,
'elb': elb_backends, "dynamodbstreams": dynamodbstreams_backends,
'elbv2': elbv2_backends, "ec2": ec2_backends,
'events': events_backends, "ec2_instance_connect": ec2_instance_connect_backends,
'emr': emr_backends, "ecr": ecr_backends,
'glacier': glacier_backends, "ecs": ecs_backends,
'glue': glue_backends, "elb": elb_backends,
'iam': iam_backends, "elbv2": elbv2_backends,
'moto_api': moto_api_backends, "events": events_backends,
'instance_metadata': instance_metadata_backends, "emr": emr_backends,
'logs': logs_backends, "glacier": glacier_backends,
'kinesis': kinesis_backends, "glue": glue_backends,
'kms': kms_backends, "iam": iam_backends,
'opsworks': opsworks_backends, "moto_api": moto_api_backends,
'organizations': organizations_backends, "instance_metadata": instance_metadata_backends,
'polly': polly_backends, "logs": logs_backends,
'redshift': redshift_backends, "kinesis": kinesis_backends,
'rds': rds2_backends, "kms": kms_backends,
's3': s3_backends, "opsworks": opsworks_backends,
's3bucket_path': s3_backends, "organizations": organizations_backends,
'ses': ses_backends, "polly": polly_backends,
'secretsmanager': secretsmanager_backends, "redshift": redshift_backends,
'sns': sns_backends, "resource-groups": resourcegroups_backends,
'sqs': sqs_backends, "rds": rds2_backends,
'ssm': ssm_backends, "s3": s3_backends,
'sts': sts_backends, "s3bucket_path": s3_backends,
'swf': swf_backends, "ses": ses_backends,
'route53': route53_backends, "secretsmanager": secretsmanager_backends,
'lambda': lambda_backends, "sns": sns_backends,
'xray': xray_backends, "sqs": sqs_backends,
'resourcegroupstaggingapi': resourcegroupstaggingapi_backends, "ssm": ssm_backends,
'iot': iot_backends, "stepfunctions": stepfunction_backends,
'iot-data': iotdata_backends, "sts": sts_backends,
"swf": swf_backends,
"route53": route53_backends,
"lambda": lambda_backends,
"xray": xray_backends,
"resourcegroupstaggingapi": resourcegroupstaggingapi_backends,
"iot": iot_backends,
"iot-data": iotdata_backends,
} }
@ -104,6 +118,6 @@ def get_model(name, region_name):
for backends in BACKENDS.values(): for backends in BACKENDS.values():
for region, backend in backends.items(): for region, backend in backends.items():
if region == region_name: if region == region_name:
models = getattr(backend.__class__, '__models__', {}) models = getattr(backend.__class__, "__models__", {})
if name in models: if name in models:
return list(getattr(backend, models[name])()) return list(getattr(backend, models[name])())

View File

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import batch_backends from .models import batch_backends
from ..core.models import base_decorator from ..core.models import base_decorator
batch_backend = batch_backends['us-east-1'] batch_backend = batch_backends["us-east-1"]
mock_batch = base_decorator(batch_backends) mock_batch = base_decorator(batch_backends)

View File

@ -12,26 +12,29 @@ class AWSError(Exception):
self.status = status if status is not None else self.STATUS self.status = status if status is not None else self.STATUS
def response(self): def response(self):
return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) return (
json.dumps({"__type": self.code, "message": self.message}),
dict(status=self.status),
)
class InvalidRequestException(AWSError): class InvalidRequestException(AWSError):
CODE = 'InvalidRequestException' CODE = "InvalidRequestException"
class InvalidParameterValueException(AWSError): class InvalidParameterValueException(AWSError):
CODE = 'InvalidParameterValue' CODE = "InvalidParameterValue"
class ValidationError(AWSError): class ValidationError(AWSError):
CODE = 'ValidationError' CODE = "ValidationError"
class InternalFailure(AWSError): class InternalFailure(AWSError):
CODE = 'InternalFailure' CODE = "InternalFailure"
STATUS = 500 STATUS = 500
class ClientException(AWSError): class ClientException(AWSError):
CODE = 'ClientException' CODE = "ClientException"
STATUS = 400 STATUS = 400

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ import json
class BatchResponse(BaseResponse): class BatchResponse(BaseResponse):
def _error(self, code, message): def _error(self, code, message):
return json.dumps({'__type': code, 'message': message}), dict(status=400) return json.dumps({"__type": code, "message": message}), dict(status=400)
@property @property
def batch_backend(self): def batch_backend(self):
@ -22,9 +22,9 @@ class BatchResponse(BaseResponse):
@property @property
def json(self): def json(self):
if self.body is None or self.body == '': if self.body is None or self.body == "":
self._json = {} self._json = {}
elif not hasattr(self, '_json'): elif not hasattr(self, "_json"):
try: try:
self._json = json.loads(self.body) self._json = json.loads(self.body)
except ValueError: except ValueError:
@ -39,153 +39,146 @@ class BatchResponse(BaseResponse):
def _get_action(self): def _get_action(self):
# Return element after the /v1/* # Return element after the /v1/*
return urlsplit(self.uri).path.lstrip('/').split('/')[1] return urlsplit(self.uri).path.lstrip("/").split("/")[1]
# CreateComputeEnvironment # CreateComputeEnvironment
def createcomputeenvironment(self): def createcomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironmentName') compute_env_name = self._get_param("computeEnvironmentName")
compute_resource = self._get_param('computeResources') compute_resource = self._get_param("computeResources")
service_role = self._get_param('serviceRole') service_role = self._get_param("serviceRole")
state = self._get_param('state') state = self._get_param("state")
_type = self._get_param('type') _type = self._get_param("type")
try: try:
name, arn = self.batch_backend.create_compute_environment( name, arn = self.batch_backend.create_compute_environment(
compute_environment_name=compute_env_name, compute_environment_name=compute_env_name,
_type=_type, state=state, _type=_type,
state=state,
compute_resources=compute_resource, compute_resources=compute_resource,
service_role=service_role service_role=service_role,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
return json.dumps(result) return json.dumps(result)
# DescribeComputeEnvironments # DescribeComputeEnvironments
def describecomputeenvironments(self): def describecomputeenvironments(self):
compute_environments = self._get_param('computeEnvironments') compute_environments = self._get_param("computeEnvironments")
max_results = self._get_param('maxResults') # Ignored, should be int max_results = self._get_param("maxResults") # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored next_token = self._get_param("nextToken") # Ignored
envs = self.batch_backend.describe_compute_environments(compute_environments, max_results=max_results, next_token=next_token) envs = self.batch_backend.describe_compute_environments(
compute_environments, max_results=max_results, next_token=next_token
)
result = {'computeEnvironments': envs} result = {"computeEnvironments": envs}
return json.dumps(result) return json.dumps(result)
# DeleteComputeEnvironment # DeleteComputeEnvironment
def deletecomputeenvironment(self): def deletecomputeenvironment(self):
compute_environment = self._get_param('computeEnvironment') compute_environment = self._get_param("computeEnvironment")
try: try:
self.batch_backend.delete_compute_environment(compute_environment) self.batch_backend.delete_compute_environment(compute_environment)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
# UpdateComputeEnvironment # UpdateComputeEnvironment
def updatecomputeenvironment(self): def updatecomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironment') compute_env_name = self._get_param("computeEnvironment")
compute_resource = self._get_param('computeResources') compute_resource = self._get_param("computeResources")
service_role = self._get_param('serviceRole') service_role = self._get_param("serviceRole")
state = self._get_param('state') state = self._get_param("state")
try: try:
name, arn = self.batch_backend.update_compute_environment( name, arn = self.batch_backend.update_compute_environment(
compute_environment_name=compute_env_name, compute_environment_name=compute_env_name,
compute_resources=compute_resource, compute_resources=compute_resource,
service_role=service_role, service_role=service_role,
state=state state=state,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
return json.dumps(result) return json.dumps(result)
# CreateJobQueue # CreateJobQueue
def createjobqueue(self): def createjobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder') compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param('jobQueueName') queue_name = self._get_param("jobQueueName")
priority = self._get_param('priority') priority = self._get_param("priority")
state = self._get_param('state') state = self._get_param("state")
try: try:
name, arn = self.batch_backend.create_job_queue( name, arn = self.batch_backend.create_job_queue(
queue_name=queue_name, queue_name=queue_name,
priority=priority, priority=priority,
state=state, state=state,
compute_env_order=compute_env_order compute_env_order=compute_env_order,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"jobQueueArn": arn, "jobQueueName": name}
'jobQueueArn': arn,
'jobQueueName': name
}
return json.dumps(result) return json.dumps(result)
# DescribeJobQueues # DescribeJobQueues
def describejobqueues(self): def describejobqueues(self):
job_queues = self._get_param('jobQueues') job_queues = self._get_param("jobQueues")
max_results = self._get_param('maxResults') # Ignored, should be int max_results = self._get_param("maxResults") # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored next_token = self._get_param("nextToken") # Ignored
queues = self.batch_backend.describe_job_queues(job_queues, max_results=max_results, next_token=next_token) queues = self.batch_backend.describe_job_queues(
job_queues, max_results=max_results, next_token=next_token
)
result = {'jobQueues': queues} result = {"jobQueues": queues}
return json.dumps(result) return json.dumps(result)
# UpdateJobQueue # UpdateJobQueue
def updatejobqueue(self): def updatejobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder') compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param('jobQueue') queue_name = self._get_param("jobQueue")
priority = self._get_param('priority') priority = self._get_param("priority")
state = self._get_param('state') state = self._get_param("state")
try: try:
name, arn = self.batch_backend.update_job_queue( name, arn = self.batch_backend.update_job_queue(
queue_name=queue_name, queue_name=queue_name,
priority=priority, priority=priority,
state=state, state=state,
compute_env_order=compute_env_order compute_env_order=compute_env_order,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"jobQueueArn": arn, "jobQueueName": name}
'jobQueueArn': arn,
'jobQueueName': name
}
return json.dumps(result) return json.dumps(result)
# DeleteJobQueue # DeleteJobQueue
def deletejobqueue(self): def deletejobqueue(self):
queue_name = self._get_param('jobQueue') queue_name = self._get_param("jobQueue")
self.batch_backend.delete_job_queue(queue_name) self.batch_backend.delete_job_queue(queue_name)
return '' return ""
# RegisterJobDefinition # RegisterJobDefinition
def registerjobdefinition(self): def registerjobdefinition(self):
container_properties = self._get_param('containerProperties') container_properties = self._get_param("containerProperties")
def_name = self._get_param('jobDefinitionName') def_name = self._get_param("jobDefinitionName")
parameters = self._get_param('parameters') parameters = self._get_param("parameters")
retry_strategy = self._get_param('retryStrategy') retry_strategy = self._get_param("retryStrategy")
_type = self._get_param('type') _type = self._get_param("type")
try: try:
name, arn, revision = self.batch_backend.register_job_definition( name, arn, revision = self.batch_backend.register_job_definition(
@ -193,104 +186,113 @@ class BatchResponse(BaseResponse):
parameters=parameters, parameters=parameters,
_type=_type, _type=_type,
retry_strategy=retry_strategy, retry_strategy=retry_strategy,
container_properties=container_properties container_properties=container_properties,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {
'jobDefinitionArn': arn, "jobDefinitionArn": arn,
'jobDefinitionName': name, "jobDefinitionName": name,
'revision': revision "revision": revision,
} }
return json.dumps(result) return json.dumps(result)
# DeregisterJobDefinition # DeregisterJobDefinition
def deregisterjobdefinition(self): def deregisterjobdefinition(self):
queue_name = self._get_param('jobDefinition') queue_name = self._get_param("jobDefinition")
self.batch_backend.deregister_job_definition(queue_name) self.batch_backend.deregister_job_definition(queue_name)
return '' return ""
# DescribeJobDefinitions # DescribeJobDefinitions
def describejobdefinitions(self): def describejobdefinitions(self):
job_def_name = self._get_param('jobDefinitionName') job_def_name = self._get_param("jobDefinitionName")
job_def_list = self._get_param('jobDefinitions') job_def_list = self._get_param("jobDefinitions")
max_results = self._get_param('maxResults') max_results = self._get_param("maxResults")
next_token = self._get_param('nextToken') next_token = self._get_param("nextToken")
status = self._get_param('status') status = self._get_param("status")
job_defs = self.batch_backend.describe_job_definitions(job_def_name, job_def_list, status, max_results, next_token) job_defs = self.batch_backend.describe_job_definitions(
job_def_name, job_def_list, status, max_results, next_token
)
result = {'jobDefinitions': [job.describe() for job in job_defs]} result = {"jobDefinitions": [job.describe() for job in job_defs]}
return json.dumps(result) return json.dumps(result)
# SubmitJob # SubmitJob
def submitjob(self): def submitjob(self):
container_overrides = self._get_param('containerOverrides') container_overrides = self._get_param("containerOverrides")
depends_on = self._get_param('dependsOn') depends_on = self._get_param("dependsOn")
job_def = self._get_param('jobDefinition') job_def = self._get_param("jobDefinition")
job_name = self._get_param('jobName') job_name = self._get_param("jobName")
job_queue = self._get_param('jobQueue') job_queue = self._get_param("jobQueue")
parameters = self._get_param('parameters') parameters = self._get_param("parameters")
retries = self._get_param('retryStrategy') retries = self._get_param("retryStrategy")
try: try:
name, job_id = self.batch_backend.submit_job( name, job_id = self.batch_backend.submit_job(
job_name, job_def, job_queue, job_name,
job_def,
job_queue,
parameters=parameters, parameters=parameters,
retries=retries, retries=retries,
depends_on=depends_on, depends_on=depends_on,
container_overrides=container_overrides container_overrides=container_overrides,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"jobId": job_id, "jobName": name}
'jobId': job_id,
'jobName': name,
}
return json.dumps(result) return json.dumps(result)
# DescribeJobs # DescribeJobs
def describejobs(self): def describejobs(self):
jobs = self._get_param('jobs') jobs = self._get_param("jobs")
try: try:
return json.dumps({'jobs': self.batch_backend.describe_jobs(jobs)}) return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
except AWSError as err: except AWSError as err:
return err.response() return err.response()
# ListJobs # ListJobs
def listjobs(self): def listjobs(self):
job_queue = self._get_param('jobQueue') job_queue = self._get_param("jobQueue")
job_status = self._get_param('jobStatus') job_status = self._get_param("jobStatus")
max_results = self._get_param('maxResults') max_results = self._get_param("maxResults")
next_token = self._get_param('nextToken') next_token = self._get_param("nextToken")
try: try:
jobs = self.batch_backend.list_jobs(job_queue, job_status, max_results, next_token) jobs = self.batch_backend.list_jobs(
job_queue, job_status, max_results, next_token
)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = {'jobSummaryList': [{'jobId': job.job_id, 'jobName': job.job_name} for job in jobs]} result = {
"jobSummaryList": [
{"jobId": job.job_id, "jobName": job.job_name} for job in jobs
]
}
return json.dumps(result) return json.dumps(result)
# TerminateJob # TerminateJob
def terminatejob(self): def terminatejob(self):
job_id = self._get_param('jobId') job_id = self._get_param("jobId")
reason = self._get_param('reason') reason = self._get_param("reason")
try: try:
self.batch_backend.terminate_job(job_id, reason) self.batch_backend.terminate_job(job_id, reason)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
# CancelJob # CancelJob
def canceljob(self): # Theres some AWS semantics on the differences but for us they're identical ;-) def canceljob(
self,
): # Theres some AWS semantics on the differences but for us they're identical ;-)
return self.terminatejob() return self.terminatejob()

View File

@ -1,25 +1,23 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import BatchResponse from .responses import BatchResponse
url_bases = [ url_bases = ["https?://batch.(.+).amazonaws.com"]
"https?://batch.(.+).amazonaws.com",
]
url_paths = { url_paths = {
'{0}/v1/createcomputeenvironment$': BatchResponse.dispatch, "{0}/v1/createcomputeenvironment$": BatchResponse.dispatch,
'{0}/v1/describecomputeenvironments$': BatchResponse.dispatch, "{0}/v1/describecomputeenvironments$": BatchResponse.dispatch,
'{0}/v1/deletecomputeenvironment': BatchResponse.dispatch, "{0}/v1/deletecomputeenvironment": BatchResponse.dispatch,
'{0}/v1/updatecomputeenvironment': BatchResponse.dispatch, "{0}/v1/updatecomputeenvironment": BatchResponse.dispatch,
'{0}/v1/createjobqueue': BatchResponse.dispatch, "{0}/v1/createjobqueue": BatchResponse.dispatch,
'{0}/v1/describejobqueues': BatchResponse.dispatch, "{0}/v1/describejobqueues": BatchResponse.dispatch,
'{0}/v1/updatejobqueue': BatchResponse.dispatch, "{0}/v1/updatejobqueue": BatchResponse.dispatch,
'{0}/v1/deletejobqueue': BatchResponse.dispatch, "{0}/v1/deletejobqueue": BatchResponse.dispatch,
'{0}/v1/registerjobdefinition': BatchResponse.dispatch, "{0}/v1/registerjobdefinition": BatchResponse.dispatch,
'{0}/v1/deregisterjobdefinition': BatchResponse.dispatch, "{0}/v1/deregisterjobdefinition": BatchResponse.dispatch,
'{0}/v1/describejobdefinitions': BatchResponse.dispatch, "{0}/v1/describejobdefinitions": BatchResponse.dispatch,
'{0}/v1/submitjob': BatchResponse.dispatch, "{0}/v1/submitjob": BatchResponse.dispatch,
'{0}/v1/describejobs': BatchResponse.dispatch, "{0}/v1/describejobs": BatchResponse.dispatch,
'{0}/v1/listjobs': BatchResponse.dispatch, "{0}/v1/listjobs": BatchResponse.dispatch,
'{0}/v1/terminatejob': BatchResponse.dispatch, "{0}/v1/terminatejob": BatchResponse.dispatch,
'{0}/v1/canceljob': BatchResponse.dispatch, "{0}/v1/canceljob": BatchResponse.dispatch,
} }

View File

@ -2,7 +2,9 @@ from __future__ import unicode_literals
def make_arn_for_compute_env(account_id, name, region_name): def make_arn_for_compute_env(account_id, name, region_name):
return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(region_name, account_id, name) return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(
region_name, account_id, name
)
def make_arn_for_job_queue(account_id, name, region_name): def make_arn_for_job_queue(account_id, name, region_name):
@ -10,7 +12,9 @@ def make_arn_for_job_queue(account_id, name, region_name):
def make_arn_for_task_def(account_id, name, revision, region_name): def make_arn_for_task_def(account_id, name, revision, region_name):
return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(region_name, account_id, name, revision) return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(
region_name, account_id, name, revision
)
def lowercase_first_key(some_dict): def lowercase_first_key(some_dict):

View File

@ -2,7 +2,6 @@ from __future__ import unicode_literals
from .models import cloudformation_backends from .models import cloudformation_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
cloudformation_backend = cloudformation_backends['us-east-1'] cloudformation_backend = cloudformation_backends["us-east-1"]
mock_cloudformation = base_decorator(cloudformation_backends) mock_cloudformation = base_decorator(cloudformation_backends)
mock_cloudformation_deprecated = deprecated_base_decorator( mock_cloudformation_deprecated = deprecated_base_decorator(cloudformation_backends)
cloudformation_backends)

View File

@ -4,26 +4,23 @@ from jinja2 import Template
class UnformattedGetAttTemplateException(Exception): class UnformattedGetAttTemplateException(Exception):
description = 'Template error: resource {0} does not support attribute type {1} in Fn::GetAtt' description = (
"Template error: resource {0} does not support attribute type {1} in Fn::GetAtt"
)
status_code = 400 status_code = 400
class ValidationError(BadRequest): class ValidationError(BadRequest):
def __init__(self, name_or_id, message=None): def __init__(self, name_or_id, message=None):
if message is None: if message is None:
message = "Stack with id {0} does not exist".format(name_or_id) message = "Stack with id {0} does not exist".format(name_or_id)
template = Template(ERROR_RESPONSE) template = Template(ERROR_RESPONSE)
super(ValidationError, self).__init__() super(ValidationError, self).__init__()
self.description = template.render( self.description = template.render(code="ValidationError", message=message)
code="ValidationError",
message=message,
)
class MissingParameterError(BadRequest): class MissingParameterError(BadRequest):
def __init__(self, parameter_name): def __init__(self, parameter_name):
template = Template(ERROR_RESPONSE) template = Template(ERROR_RESPONSE)
super(MissingParameterError, self).__init__() super(MissingParameterError, self).__init__()
@ -40,8 +37,8 @@ class ExportNotFound(BadRequest):
template = Template(ERROR_RESPONSE) template = Template(ERROR_RESPONSE)
super(ExportNotFound, self).__init__() super(ExportNotFound, self).__init__()
self.description = template.render( self.description = template.render(
code='ExportNotFound', code="ExportNotFound",
message="No export named {0} found.".format(export_name) message="No export named {0} found.".format(export_name),
) )

View File

@ -4,7 +4,8 @@ import json
import yaml import yaml
import uuid import uuid
import boto.cloudformation from boto3 import Session
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
@ -21,11 +22,19 @@ from .exceptions import ValidationError
class FakeStackSet(BaseModel): class FakeStackSet(BaseModel):
def __init__(
def __init__(self, stackset_id, name, template, region='us-east-1', self,
status='ACTIVE', description=None, parameters=None, tags=None, stackset_id,
admin_role='AWSCloudFormationStackSetAdministrationRole', name,
execution_role='AWSCloudFormationStackSetExecutionRole'): template,
region="us-east-1",
status="ACTIVE",
description=None,
parameters=None,
tags=None,
admin_role="AWSCloudFormationStackSetAdministrationRole",
execution_role="AWSCloudFormationStackSetExecutionRole",
):
self.id = stackset_id self.id = stackset_id
self.arn = generate_stackset_arn(stackset_id, region) self.arn = generate_stackset_arn(stackset_id, region)
self.name = name self.name = name
@ -42,12 +51,14 @@ class FakeStackSet(BaseModel):
def _create_operation(self, operation_id, action, status, accounts=[], regions=[]): def _create_operation(self, operation_id, action, status, accounts=[], regions=[]):
operation = { operation = {
'OperationId': str(operation_id), "OperationId": str(operation_id),
'Action': action, "Action": action,
'Status': status, "Status": status,
'CreationTimestamp': datetime.now(), "CreationTimestamp": datetime.now(),
'EndTimestamp': datetime.now() + timedelta(minutes=2), "EndTimestamp": datetime.now() + timedelta(minutes=2),
'Instances': [{account: region} for account in accounts for region in regions], "Instances": [
{account: region} for account in accounts for region in regions
],
} }
self.operations += [operation] self.operations += [operation]
@ -55,20 +66,30 @@ class FakeStackSet(BaseModel):
def get_operation(self, operation_id): def get_operation(self, operation_id):
for operation in self.operations: for operation in self.operations:
if operation_id == operation['OperationId']: if operation_id == operation["OperationId"]:
return operation return operation
raise ValidationError(operation_id) raise ValidationError(operation_id)
def update_operation(self, operation_id, status): def update_operation(self, operation_id, status):
operation = self.get_operation(operation_id) operation = self.get_operation(operation_id)
operation['Status'] = status operation["Status"] = status
return operation_id return operation_id
def delete(self): def delete(self):
self.status = 'DELETED' self.status = "DELETED"
def update(self, template, description, parameters, tags, admin_role, def update(
execution_role, accounts, regions, operation_id=None): self,
template,
description,
parameters,
tags,
admin_role,
execution_role,
accounts,
regions,
operation_id=None,
):
if not operation_id: if not operation_id:
operation_id = uuid.uuid4() operation_id = uuid.uuid4()
@ -82,9 +103,13 @@ class FakeStackSet(BaseModel):
if accounts and regions: if accounts and regions:
self.update_instances(accounts, regions, self.parameters) self.update_instances(accounts, regions, self.parameters)
operation = self._create_operation(operation_id=operation_id, operation = self._create_operation(
action='UPDATE', status='SUCCEEDED', accounts=accounts, operation_id=operation_id,
regions=regions) action="UPDATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation return operation
def create_stack_instances(self, accounts, regions, parameters, operation_id=None): def create_stack_instances(self, accounts, regions, parameters, operation_id=None):
@ -94,8 +119,13 @@ class FakeStackSet(BaseModel):
parameters = self.parameters parameters = self.parameters
self.instances.create_instances(accounts, regions, parameters, operation_id) self.instances.create_instances(accounts, regions, parameters, operation_id)
self._create_operation(operation_id=operation_id, action='CREATE', self._create_operation(
status='SUCCEEDED', accounts=accounts, regions=regions) operation_id=operation_id,
action="CREATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
def delete_stack_instances(self, accounts, regions, operation_id=None): def delete_stack_instances(self, accounts, regions, operation_id=None):
if not operation_id: if not operation_id:
@ -103,8 +133,13 @@ class FakeStackSet(BaseModel):
self.instances.delete(accounts, regions) self.instances.delete(accounts, regions)
operation = self._create_operation(operation_id=operation_id, action='DELETE', operation = self._create_operation(
status='SUCCEEDED', accounts=accounts, regions=regions) operation_id=operation_id,
action="DELETE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation return operation
def update_instances(self, accounts, regions, parameters, operation_id=None): def update_instances(self, accounts, regions, parameters, operation_id=None):
@ -112,9 +147,13 @@ class FakeStackSet(BaseModel):
operation_id = uuid.uuid4() operation_id = uuid.uuid4()
self.instances.update(accounts, regions, parameters) self.instances.update(accounts, regions, parameters)
operation = self._create_operation(operation_id=operation_id, operation = self._create_operation(
action='UPDATE', status='SUCCEEDED', accounts=accounts, operation_id=operation_id,
regions=regions) action="UPDATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation return operation
@ -131,12 +170,12 @@ class FakeStackInstances(BaseModel):
for region in regions: for region in regions:
for account in accounts: for account in accounts:
instance = { instance = {
'StackId': generate_stack_id(self.stack_name, region, account), "StackId": generate_stack_id(self.stack_name, region, account),
'StackSetId': self.stackset_id, "StackSetId": self.stackset_id,
'Region': region, "Region": region,
'Account': account, "Account": account,
'Status': "CURRENT", "Status": "CURRENT",
'ParameterOverrides': parameters if parameters else [], "ParameterOverrides": parameters if parameters else [],
} }
new_instances.append(instance) new_instances.append(instance)
self.stack_instances += new_instances self.stack_instances += new_instances
@ -147,24 +186,35 @@ class FakeStackInstances(BaseModel):
for region in regions: for region in regions:
instance = self.get_instance(account, region) instance = self.get_instance(account, region)
if parameters: if parameters:
instance['ParameterOverrides'] = parameters instance["ParameterOverrides"] = parameters
else: else:
instance['ParameterOverrides'] = [] instance["ParameterOverrides"] = []
def delete(self, accounts, regions): def delete(self, accounts, regions):
for i, instance in enumerate(self.stack_instances): for i, instance in enumerate(self.stack_instances):
if instance['Region'] in regions and instance['Account'] in accounts: if instance["Region"] in regions and instance["Account"] in accounts:
self.stack_instances.pop(i) self.stack_instances.pop(i)
def get_instance(self, account, region): def get_instance(self, account, region):
for i, instance in enumerate(self.stack_instances): for i, instance in enumerate(self.stack_instances):
if instance['Region'] == region and instance['Account'] == account: if instance["Region"] == region and instance["Account"] == account:
return self.stack_instances[i] return self.stack_instances[i]
class FakeStack(BaseModel): class FakeStack(BaseModel):
def __init__(
def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None, create_change_set=False): self,
stack_id,
name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
cross_stack_resources=None,
create_change_set=False,
):
self.stack_id = stack_id self.stack_id = stack_id
self.name = name self.name = name
self.template = template self.template = template
@ -176,22 +226,31 @@ class FakeStack(BaseModel):
self.tags = tags if tags else {} self.tags = tags if tags else {}
self.events = [] self.events = []
if create_change_set: if create_change_set:
self._add_stack_event("REVIEW_IN_PROGRESS", self._add_stack_event(
resource_status_reason="User Initiated") "REVIEW_IN_PROGRESS", resource_status_reason="User Initiated"
)
else: else:
self._add_stack_event("CREATE_IN_PROGRESS", self._add_stack_event(
resource_status_reason="User Initiated") "CREATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.description = self.template_dict.get('Description') self.description = self.template_dict.get("Description")
self.cross_stack_resources = cross_stack_resources or {} self.cross_stack_resources = cross_stack_resources or {}
self.resource_map = self._create_resource_map() self.resource_map = self._create_resource_map()
self.output_map = self._create_output_map() self.output_map = self._create_output_map()
self._add_stack_event("CREATE_COMPLETE") self._add_stack_event("CREATE_COMPLETE")
self.status = 'CREATE_COMPLETE' self.status = "CREATE_COMPLETE"
def _create_resource_map(self): def _create_resource_map(self):
resource_map = ResourceMap( resource_map = ResourceMap(
self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict, self.cross_stack_resources) self.stack_id,
self.name,
self.parameters,
self.tags,
self.region_name,
self.template_dict,
self.cross_stack_resources,
)
resource_map.create() resource_map.create()
return resource_map return resource_map
@ -200,34 +259,46 @@ class FakeStack(BaseModel):
output_map.create() output_map.create()
return output_map return output_map
def _add_stack_event(self, resource_status, resource_status_reason=None, resource_properties=None): def _add_stack_event(
self.events.append(FakeEvent( self, resource_status, resource_status_reason=None, resource_properties=None
stack_id=self.stack_id, ):
stack_name=self.name, self.events.append(
logical_resource_id=self.name, FakeEvent(
physical_resource_id=self.stack_id, stack_id=self.stack_id,
resource_type="AWS::CloudFormation::Stack", stack_name=self.name,
resource_status=resource_status, logical_resource_id=self.name,
resource_status_reason=resource_status_reason, physical_resource_id=self.stack_id,
resource_properties=resource_properties, resource_type="AWS::CloudFormation::Stack",
)) resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
)
)
def _add_resource_event(self, logical_resource_id, resource_status, resource_status_reason=None, resource_properties=None): def _add_resource_event(
self,
logical_resource_id,
resource_status,
resource_status_reason=None,
resource_properties=None,
):
# not used yet... feel free to help yourself # not used yet... feel free to help yourself
resource = self.resource_map[logical_resource_id] resource = self.resource_map[logical_resource_id]
self.events.append(FakeEvent( self.events.append(
stack_id=self.stack_id, FakeEvent(
stack_name=self.name, stack_id=self.stack_id,
logical_resource_id=logical_resource_id, stack_name=self.name,
physical_resource_id=resource.physical_resource_id, logical_resource_id=logical_resource_id,
resource_type=resource.type, physical_resource_id=resource.physical_resource_id,
resource_status=resource_status, resource_type=resource.type,
resource_status_reason=resource_status_reason, resource_status=resource_status,
resource_properties=resource_properties, resource_status_reason=resource_status_reason,
)) resource_properties=resource_properties,
)
)
def _parse_template(self): def _parse_template(self):
yaml.add_multi_constructor('', yaml_tag_constructor) yaml.add_multi_constructor("", yaml_tag_constructor)
try: try:
self.template_dict = yaml.load(self.template, Loader=yaml.Loader) self.template_dict = yaml.load(self.template, Loader=yaml.Loader)
except yaml.parser.ParserError: except yaml.parser.ParserError:
@ -250,7 +321,9 @@ class FakeStack(BaseModel):
return self.output_map.exports return self.output_map.exports
def update(self, template, role_arn=None, parameters=None, tags=None): def update(self, template, role_arn=None, parameters=None, tags=None):
self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated") self._add_stack_event(
"UPDATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.template = template self.template = template
self._parse_template() self._parse_template()
self.resource_map.update(self.template_dict, parameters) self.resource_map.update(self.template_dict, parameters)
@ -264,15 +337,15 @@ class FakeStack(BaseModel):
# TODO: update tags in the resource map # TODO: update tags in the resource map
def delete(self): def delete(self):
self._add_stack_event("DELETE_IN_PROGRESS", self._add_stack_event(
resource_status_reason="User Initiated") "DELETE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.resource_map.delete() self.resource_map.delete()
self._add_stack_event("DELETE_COMPLETE") self._add_stack_event("DELETE_COMPLETE")
self.status = "DELETE_COMPLETE" self.status = "DELETE_COMPLETE"
class FakeChange(BaseModel): class FakeChange(BaseModel):
def __init__(self, action, logical_resource_id, resource_type): def __init__(self, action, logical_resource_id, resource_type):
self.action = action self.action = action
self.logical_resource_id = logical_resource_id self.logical_resource_id = logical_resource_id
@ -280,8 +353,21 @@ class FakeChange(BaseModel):
class FakeChangeSet(FakeStack): class FakeChangeSet(FakeStack):
def __init__(
def __init__(self, stack_id, stack_name, stack_template, change_set_id, change_set_name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None): self,
stack_id,
stack_name,
stack_template,
change_set_id,
change_set_name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
cross_stack_resources=None,
):
super(FakeChangeSet, self).__init__( super(FakeChangeSet, self).__init__(
stack_id, stack_id,
stack_name, stack_name,
@ -306,17 +392,28 @@ class FakeChangeSet(FakeStack):
resources_by_action = self.resource_map.diff(self.template_dict, parameters) resources_by_action = self.resource_map.diff(self.template_dict, parameters)
for action, resources in resources_by_action.items(): for action, resources in resources_by_action.items():
for resource_name, resource in resources.items(): for resource_name, resource in resources.items():
changes.append(FakeChange( changes.append(
action=action, FakeChange(
logical_resource_id=resource_name, action=action,
resource_type=resource['ResourceType'], logical_resource_id=resource_name,
)) resource_type=resource["ResourceType"],
)
)
return changes return changes
class FakeEvent(BaseModel): class FakeEvent(BaseModel):
def __init__(
def __init__(self, stack_id, stack_name, logical_resource_id, physical_resource_id, resource_type, resource_status, resource_status_reason=None, resource_properties=None): self,
stack_id,
stack_name,
logical_resource_id,
physical_resource_id,
resource_type,
resource_status,
resource_status_reason=None,
resource_properties=None,
):
self.stack_id = stack_id self.stack_id = stack_id
self.stack_name = stack_name self.stack_name = stack_name
self.logical_resource_id = logical_resource_id self.logical_resource_id = logical_resource_id
@ -330,7 +427,6 @@ class FakeEvent(BaseModel):
class CloudFormationBackend(BaseBackend): class CloudFormationBackend(BaseBackend):
def __init__(self): def __init__(self):
self.stacks = OrderedDict() self.stacks = OrderedDict()
self.stacksets = OrderedDict() self.stacksets = OrderedDict()
@ -338,7 +434,17 @@ class CloudFormationBackend(BaseBackend):
self.exports = OrderedDict() self.exports = OrderedDict()
self.change_sets = OrderedDict() self.change_sets = OrderedDict()
def create_stack_set(self, name, template, parameters, tags=None, description=None, region='us-east-1', admin_role=None, execution_role=None): def create_stack_set(
self,
name,
template,
parameters,
tags=None,
description=None,
region="us-east-1",
admin_role=None,
execution_role=None,
):
stackset_id = generate_stackset_id(name) stackset_id = generate_stackset_id(name)
new_stackset = FakeStackSet( new_stackset = FakeStackSet(
stackset_id=stackset_id, stackset_id=stackset_id,
@ -366,7 +472,9 @@ class CloudFormationBackend(BaseBackend):
if self.stacksets[stackset].name == name: if self.stacksets[stackset].name == name:
self.stacksets[stackset].delete() self.stacksets[stackset].delete()
def create_stack_instances(self, stackset_name, accounts, regions, parameters, operation_id=None): def create_stack_instances(
self, stackset_name, accounts, regions, parameters, operation_id=None
):
stackset = self.get_stack_set(stackset_name) stackset = self.get_stack_set(stackset_name)
stackset.create_stack_instances( stackset.create_stack_instances(
@ -377,9 +485,19 @@ class CloudFormationBackend(BaseBackend):
) )
return stackset return stackset
def update_stack_set(self, stackset_name, template=None, description=None, def update_stack_set(
parameters=None, tags=None, admin_role=None, execution_role=None, self,
accounts=None, regions=None, operation_id=None): stackset_name,
template=None,
description=None,
parameters=None,
tags=None,
admin_role=None,
execution_role=None,
accounts=None,
regions=None,
operation_id=None,
):
stackset = self.get_stack_set(stackset_name) stackset = self.get_stack_set(stackset_name)
update = stackset.update( update = stackset.update(
template=template, template=template,
@ -390,16 +508,28 @@ class CloudFormationBackend(BaseBackend):
execution_role=execution_role, execution_role=execution_role,
accounts=accounts, accounts=accounts,
regions=regions, regions=regions,
operation_id=operation_id operation_id=operation_id,
) )
return update return update
def delete_stack_instances(self, stackset_name, accounts, regions, operation_id=None): def delete_stack_instances(
self, stackset_name, accounts, regions, operation_id=None
):
stackset = self.get_stack_set(stackset_name) stackset = self.get_stack_set(stackset_name)
stackset.delete_stack_instances(accounts, regions, operation_id) stackset.delete_stack_instances(accounts, regions, operation_id)
return stackset return stackset
def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, create_change_set=False): def create_stack(
self,
name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
create_change_set=False,
):
stack_id = generate_stack_id(name) stack_id = generate_stack_id(name)
new_stack = FakeStack( new_stack = FakeStack(
stack_id=stack_id, stack_id=stack_id,
@ -419,10 +549,21 @@ class CloudFormationBackend(BaseBackend):
self.exports[export.name] = export self.exports[export.name] = export
return new_stack return new_stack
def create_change_set(self, stack_name, change_set_name, template, parameters, region_name, change_set_type, notification_arns=None, tags=None, role_arn=None): def create_change_set(
self,
stack_name,
change_set_name,
template,
parameters,
region_name,
change_set_type,
notification_arns=None,
tags=None,
role_arn=None,
):
stack_id = None stack_id = None
stack_template = None stack_template = None
if change_set_type == 'UPDATE': if change_set_type == "UPDATE":
stacks = self.stacks.values() stacks = self.stacks.values()
stack = None stack = None
for s in stacks: for s in stacks:
@ -449,7 +590,7 @@ class CloudFormationBackend(BaseBackend):
notification_arns=notification_arns, notification_arns=notification_arns,
tags=tags, tags=tags,
role_arn=role_arn, role_arn=role_arn,
cross_stack_resources=self.exports cross_stack_resources=self.exports,
) )
self.change_sets[change_set_id] = new_change_set self.change_sets[change_set_id] = new_change_set
self.stacks[stack_id] = new_change_set self.stacks[stack_id] = new_change_set
@ -488,11 +629,11 @@ class CloudFormationBackend(BaseBackend):
stack = self.change_sets[cs] stack = self.change_sets[cs]
if stack is None: if stack is None:
raise ValidationError(stack_name) raise ValidationError(stack_name)
if stack.events[-1].resource_status == 'REVIEW_IN_PROGRESS': if stack.events[-1].resource_status == "REVIEW_IN_PROGRESS":
stack._add_stack_event('CREATE_COMPLETE') stack._add_stack_event("CREATE_COMPLETE")
else: else:
stack._add_stack_event('UPDATE_IN_PROGRESS') stack._add_stack_event("UPDATE_IN_PROGRESS")
stack._add_stack_event('UPDATE_COMPLETE') stack._add_stack_event("UPDATE_COMPLETE")
return True return True
def describe_stacks(self, name_or_stack_id): def describe_stacks(self, name_or_stack_id):
@ -514,9 +655,7 @@ class CloudFormationBackend(BaseBackend):
return self.change_sets.values() return self.change_sets.values()
def list_stacks(self): def list_stacks(self):
return [ return [v for v in self.stacks.values()] + [
v for v in self.stacks.values()
] + [
v for v in self.deleted_stacks.values() v for v in self.deleted_stacks.values()
] ]
@ -558,10 +697,10 @@ class CloudFormationBackend(BaseBackend):
all_exports = list(self.exports.values()) all_exports = list(self.exports.values())
if token is None: if token is None:
exports = all_exports[0:100] exports = all_exports[0:100]
next_token = '100' if len(all_exports) > 100 else None next_token = "100" if len(all_exports) > 100 else None
else: else:
token = int(token) token = int(token)
exports = all_exports[token:token + 100] exports = all_exports[token : token + 100]
next_token = str(token + 100) if len(all_exports) > token + 100 else None next_token = str(token + 100) if len(all_exports) > token + 100 else None
return exports, next_token return exports, next_token
@ -572,9 +711,20 @@ class CloudFormationBackend(BaseBackend):
new_stack_export_names = [x.name for x in stack.exports] new_stack_export_names = [x.name for x in stack.exports]
export_names = self.exports.keys() export_names = self.exports.keys()
if not set(export_names).isdisjoint(new_stack_export_names): if not set(export_names).isdisjoint(new_stack_export_names):
raise ValidationError(stack.stack_id, message='Export names must be unique across a given region') raise ValidationError(
stack.stack_id,
message="Export names must be unique across a given region",
)
cloudformation_backends = {} cloudformation_backends = {}
for region in boto.cloudformation.regions(): for region in Session().get_available_regions("cloudformation"):
cloudformation_backends[region.name] = CloudFormationBackend() cloudformation_backends[region] = CloudFormationBackend()
for region in Session().get_available_regions(
"cloudformation", partition_name="aws-us-gov"
):
cloudformation_backends[region] = CloudFormationBackend()
for region in Session().get_available_regions(
"cloudformation", partition_name="aws-cn"
):
cloudformation_backends[region] = CloudFormationBackend()

View File

@ -1,5 +1,4 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import collections
import functools import functools
import logging import logging
import copy import copy
@ -11,8 +10,9 @@ from moto.awslambda import models as lambda_models
from moto.batch import models as batch_models from moto.batch import models as batch_models
from moto.cloudwatch import models as cloudwatch_models from moto.cloudwatch import models as cloudwatch_models
from moto.cognitoidentity import models as cognitoidentity_models from moto.cognitoidentity import models as cognitoidentity_models
from moto.compat import collections_abc
from moto.datapipeline import models as datapipeline_models from moto.datapipeline import models as datapipeline_models
from moto.dynamodb import models as dynamodb_models from moto.dynamodb2 import models as dynamodb2_models
from moto.ec2 import models as ec2_models from moto.ec2 import models as ec2_models
from moto.ecs import models as ecs_models from moto.ecs import models as ecs_models
from moto.elb import models as elb_models from moto.elb import models as elb_models
@ -27,8 +27,14 @@ from moto.route53 import models as route53_models
from moto.s3 import models as s3_models from moto.s3 import models as s3_models
from moto.sns import models as sns_models from moto.sns import models as sns_models
from moto.sqs import models as sqs_models from moto.sqs import models as sqs_models
from moto.core import ACCOUNT_ID
from .utils import random_suffix from .utils import random_suffix
from .exceptions import ExportNotFound, MissingParameterError, UnformattedGetAttTemplateException, ValidationError from .exceptions import (
ExportNotFound,
MissingParameterError,
UnformattedGetAttTemplateException,
ValidationError,
)
from boto.cloudformation.stack import Output from boto.cloudformation.stack import Output
MODEL_MAP = { MODEL_MAP = {
@ -37,7 +43,7 @@ MODEL_MAP = {
"AWS::Batch::JobDefinition": batch_models.JobDefinition, "AWS::Batch::JobDefinition": batch_models.JobDefinition,
"AWS::Batch::JobQueue": batch_models.JobQueue, "AWS::Batch::JobQueue": batch_models.JobQueue,
"AWS::Batch::ComputeEnvironment": batch_models.ComputeEnvironment, "AWS::Batch::ComputeEnvironment": batch_models.ComputeEnvironment,
"AWS::DynamoDB::Table": dynamodb_models.Table, "AWS::DynamoDB::Table": dynamodb2_models.Table,
"AWS::Kinesis::Stream": kinesis_models.Stream, "AWS::Kinesis::Stream": kinesis_models.Stream,
"AWS::Lambda::EventSourceMapping": lambda_models.EventSourceMapping, "AWS::Lambda::EventSourceMapping": lambda_models.EventSourceMapping,
"AWS::Lambda::Function": lambda_models.LambdaFunction, "AWS::Lambda::Function": lambda_models.LambdaFunction,
@ -100,7 +106,7 @@ NAME_TYPE_MAP = {
"AWS::RDS::DBInstance": "DBInstanceIdentifier", "AWS::RDS::DBInstance": "DBInstanceIdentifier",
"AWS::S3::Bucket": "BucketName", "AWS::S3::Bucket": "BucketName",
"AWS::SNS::Topic": "TopicName", "AWS::SNS::Topic": "TopicName",
"AWS::SQS::Queue": "QueueName" "AWS::SQS::Queue": "QueueName",
} }
# Just ignore these models types for now # Just ignore these models types for now
@ -109,13 +115,12 @@ NULL_MODELS = [
"AWS::CloudFormation::WaitConditionHandle", "AWS::CloudFormation::WaitConditionHandle",
] ]
DEFAULT_REGION = 'us-east-1' DEFAULT_REGION = "us-east-1"
logger = logging.getLogger("moto") logger = logging.getLogger("moto")
class LazyDict(dict): class LazyDict(dict):
def __getitem__(self, key): def __getitem__(self, key):
val = dict.__getitem__(self, key) val = dict.__getitem__(self, key)
if callable(val): if callable(val):
@ -132,10 +137,10 @@ def clean_json(resource_json, resources_map):
Eventually, this is where we would add things like function parsing (fn::) Eventually, this is where we would add things like function parsing (fn::)
""" """
if isinstance(resource_json, dict): if isinstance(resource_json, dict):
if 'Ref' in resource_json: if "Ref" in resource_json:
# Parse resource reference # Parse resource reference
resource = resources_map[resource_json['Ref']] resource = resources_map[resource_json["Ref"]]
if hasattr(resource, 'physical_resource_id'): if hasattr(resource, "physical_resource_id"):
return resource.physical_resource_id return resource.physical_resource_id
else: else:
return resource return resource
@ -148,74 +153,92 @@ def clean_json(resource_json, resources_map):
result = result[clean_json(path, resources_map)] result = result[clean_json(path, resources_map)]
return result return result
if 'Fn::GetAtt' in resource_json: if "Fn::GetAtt" in resource_json:
resource = resources_map.get(resource_json['Fn::GetAtt'][0]) resource = resources_map.get(resource_json["Fn::GetAtt"][0])
if resource is None: if resource is None:
return resource_json return resource_json
try: try:
return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1]) return resource.get_cfn_attribute(resource_json["Fn::GetAtt"][1])
except NotImplementedError as n: except NotImplementedError as n:
logger.warning(str(n).format( logger.warning(str(n).format(resource_json["Fn::GetAtt"][0]))
resource_json['Fn::GetAtt'][0]))
except UnformattedGetAttTemplateException: except UnformattedGetAttTemplateException:
raise ValidationError( raise ValidationError(
'Bad Request', "Bad Request",
UnformattedGetAttTemplateException.description.format( UnformattedGetAttTemplateException.description.format(
resource_json['Fn::GetAtt'][0], resource_json['Fn::GetAtt'][1])) resource_json["Fn::GetAtt"][0], resource_json["Fn::GetAtt"][1]
),
)
if 'Fn::If' in resource_json: if "Fn::If" in resource_json:
condition_name, true_value, false_value = resource_json['Fn::If'] condition_name, true_value, false_value = resource_json["Fn::If"]
if resources_map.lazy_condition_map[condition_name]: if resources_map.lazy_condition_map[condition_name]:
return clean_json(true_value, resources_map) return clean_json(true_value, resources_map)
else: else:
return clean_json(false_value, resources_map) return clean_json(false_value, resources_map)
if 'Fn::Join' in resource_json: if "Fn::Join" in resource_json:
join_list = clean_json(resource_json['Fn::Join'][1], resources_map) join_list = clean_json(resource_json["Fn::Join"][1], resources_map)
return resource_json['Fn::Join'][0].join([str(x) for x in join_list]) return resource_json["Fn::Join"][0].join([str(x) for x in join_list])
if 'Fn::Split' in resource_json: if "Fn::Split" in resource_json:
to_split = clean_json(resource_json['Fn::Split'][1], resources_map) to_split = clean_json(resource_json["Fn::Split"][1], resources_map)
return to_split.split(resource_json['Fn::Split'][0]) return to_split.split(resource_json["Fn::Split"][0])
if 'Fn::Select' in resource_json: if "Fn::Select" in resource_json:
select_index = int(resource_json['Fn::Select'][0]) select_index = int(resource_json["Fn::Select"][0])
select_list = clean_json(resource_json['Fn::Select'][1], resources_map) select_list = clean_json(resource_json["Fn::Select"][1], resources_map)
return select_list[select_index] return select_list[select_index]
if 'Fn::Sub' in resource_json: if "Fn::Sub" in resource_json:
if isinstance(resource_json['Fn::Sub'], list): if isinstance(resource_json["Fn::Sub"], list):
warnings.warn( warnings.warn(
"Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation") "Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation"
)
else: else:
fn_sub_value = clean_json(resource_json['Fn::Sub'], resources_map) fn_sub_value = clean_json(resource_json["Fn::Sub"], resources_map)
to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value) to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value)
literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value) literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value)
for sub in to_sub: for sub in to_sub:
if '.' in sub: if "." in sub:
cleaned_ref = clean_json({'Fn::GetAtt': re.findall('(?<=\${)[^"]*?(?=})', sub)[0].split('.')}, resources_map) cleaned_ref = clean_json(
{
"Fn::GetAtt": re.findall('(?<=\${)[^"]*?(?=})', sub)[
0
].split(".")
},
resources_map,
)
else: else:
cleaned_ref = clean_json({'Ref': re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, resources_map) cleaned_ref = clean_json(
{"Ref": re.findall('(?<=\${)[^"]*?(?=})', sub)[0]},
resources_map,
)
fn_sub_value = fn_sub_value.replace(sub, cleaned_ref) fn_sub_value = fn_sub_value.replace(sub, cleaned_ref)
for literal in literals: for literal in literals:
fn_sub_value = fn_sub_value.replace(literal, literal.replace('!', '')) fn_sub_value = fn_sub_value.replace(
literal, literal.replace("!", "")
)
return fn_sub_value return fn_sub_value
pass pass
if 'Fn::ImportValue' in resource_json: if "Fn::ImportValue" in resource_json:
cleaned_val = clean_json(resource_json['Fn::ImportValue'], resources_map) cleaned_val = clean_json(resource_json["Fn::ImportValue"], resources_map)
values = [x.value for x in resources_map.cross_stack_resources.values() if x.name == cleaned_val] values = [
x.value
for x in resources_map.cross_stack_resources.values()
if x.name == cleaned_val
]
if any(values): if any(values):
return values[0] return values[0]
else: else:
raise ExportNotFound(cleaned_val) raise ExportNotFound(cleaned_val)
if 'Fn::GetAZs' in resource_json: if "Fn::GetAZs" in resource_json:
region = resource_json.get('Fn::GetAZs') or DEFAULT_REGION region = resource_json.get("Fn::GetAZs") or DEFAULT_REGION
result = [] result = []
# TODO: make this configurable, to reflect the real AWS AZs # TODO: make this configurable, to reflect the real AWS AZs
for az in ('a', 'b', 'c', 'd'): for az in ("a", "b", "c", "d"):
result.append('%s%s' % (region, az)) result.append("%s%s" % (region, az))
return result return result
cleaned_json = {} cleaned_json = {}
@ -246,57 +269,69 @@ def resource_name_property_from_type(resource_type):
def generate_resource_name(resource_type, stack_name, logical_id): def generate_resource_name(resource_type, stack_name, logical_id):
if resource_type == "AWS::ElasticLoadBalancingV2::TargetGroup": if resource_type in [
"AWS::ElasticLoadBalancingV2::TargetGroup",
"AWS::ElasticLoadBalancingV2::LoadBalancer",
]:
# Target group names need to be less than 32 characters, so when cloudformation creates a name for you # Target group names need to be less than 32 characters, so when cloudformation creates a name for you
# it makes sure to stay under that limit # it makes sure to stay under that limit
name_prefix = '{0}-{1}'.format(stack_name, logical_id) name_prefix = "{0}-{1}".format(stack_name, logical_id)
my_random_suffix = random_suffix() my_random_suffix = random_suffix()
truncated_name_prefix = name_prefix[0:32 - (len(my_random_suffix) + 1)] truncated_name_prefix = name_prefix[0 : 32 - (len(my_random_suffix) + 1)]
# if the truncated name ends in a dash, we'll end up with a double dash in the final name, which is # if the truncated name ends in a dash, we'll end up with a double dash in the final name, which is
# not allowed # not allowed
if truncated_name_prefix.endswith('-'): if truncated_name_prefix.endswith("-"):
truncated_name_prefix = truncated_name_prefix[:-1] truncated_name_prefix = truncated_name_prefix[:-1]
return '{0}-{1}'.format(truncated_name_prefix, my_random_suffix) return "{0}-{1}".format(truncated_name_prefix, my_random_suffix)
else: else:
return '{0}-{1}-{2}'.format(stack_name, logical_id, random_suffix()) return "{0}-{1}-{2}".format(stack_name, logical_id, random_suffix())
def parse_resource(logical_id, resource_json, resources_map): def parse_resource(logical_id, resource_json, resources_map):
resource_type = resource_json['Type'] resource_type = resource_json["Type"]
resource_class = resource_class_from_type(resource_type) resource_class = resource_class_from_type(resource_type)
if not resource_class: if not resource_class:
warnings.warn( warnings.warn(
"Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(resource_type)) "Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(
resource_type
)
)
return None return None
resource_json = clean_json(resource_json, resources_map) resource_json = clean_json(resource_json, resources_map)
resource_name_property = resource_name_property_from_type(resource_type) resource_name_property = resource_name_property_from_type(resource_type)
if resource_name_property: if resource_name_property:
if 'Properties' not in resource_json: if "Properties" not in resource_json:
resource_json['Properties'] = dict() resource_json["Properties"] = dict()
if resource_name_property not in resource_json['Properties']: if resource_name_property not in resource_json["Properties"]:
resource_json['Properties'][resource_name_property] = generate_resource_name( resource_json["Properties"][
resource_type, resources_map.get('AWS::StackName'), logical_id) resource_name_property
resource_name = resource_json['Properties'][resource_name_property] ] = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
resource_name = resource_json["Properties"][resource_name_property]
else: else:
resource_name = generate_resource_name(resource_type, resources_map.get('AWS::StackName'), logical_id) resource_name = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
return resource_class, resource_json, resource_name return resource_class, resource_json, resource_name
def parse_and_create_resource(logical_id, resource_json, resources_map, region_name): def parse_and_create_resource(logical_id, resource_json, resources_map, region_name):
condition = resource_json.get('Condition') condition = resource_json.get("Condition")
if condition and not resources_map.lazy_condition_map[condition]: if condition and not resources_map.lazy_condition_map[condition]:
# If this has a False condition, don't create the resource # If this has a False condition, don't create the resource
return None return None
resource_type = resource_json['Type'] resource_type = resource_json["Type"]
resource_tuple = parse_resource(logical_id, resource_json, resources_map) resource_tuple = parse_resource(logical_id, resource_json, resources_map)
if not resource_tuple: if not resource_tuple:
return None return None
resource_class, resource_json, resource_name = resource_tuple resource_class, resource_json, resource_name = resource_tuple
resource = resource_class.create_from_cloudformation_json( resource = resource_class.create_from_cloudformation_json(
resource_name, resource_json, region_name) resource_name, resource_json, region_name
)
resource.type = resource_type resource.type = resource_type
resource.logical_resource_id = logical_id resource.logical_resource_id = logical_id
return resource return resource
@ -304,24 +339,27 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n
def parse_and_update_resource(logical_id, resource_json, resources_map, region_name): def parse_and_update_resource(logical_id, resource_json, resources_map, region_name):
resource_class, new_resource_json, new_resource_name = parse_resource( resource_class, new_resource_json, new_resource_name = parse_resource(
logical_id, resource_json, resources_map) logical_id, resource_json, resources_map
)
original_resource = resources_map[logical_id] original_resource = resources_map[logical_id]
new_resource = resource_class.update_from_cloudformation_json( new_resource = resource_class.update_from_cloudformation_json(
original_resource=original_resource, original_resource=original_resource,
new_resource_name=new_resource_name, new_resource_name=new_resource_name,
cloudformation_json=new_resource_json, cloudformation_json=new_resource_json,
region_name=region_name region_name=region_name,
) )
new_resource.type = resource_json['Type'] new_resource.type = resource_json["Type"]
new_resource.logical_resource_id = logical_id new_resource.logical_resource_id = logical_id
return new_resource return new_resource
def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name): def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name):
resource_class, resource_json, resource_name = parse_resource( resource_class, resource_json, resource_name = parse_resource(
logical_id, resource_json, resources_map) logical_id, resource_json, resources_map
)
resource_class.delete_from_cloudformation_json( resource_class.delete_from_cloudformation_json(
resource_name, resource_json, region_name) resource_name, resource_json, region_name
)
def parse_condition(condition, resources_map, condition_map): def parse_condition(condition, resources_map, condition_map):
@ -333,8 +371,8 @@ def parse_condition(condition, resources_map, condition_map):
condition_values = [] condition_values = []
for value in list(condition.values())[0]: for value in list(condition.values())[0]:
# Check if we are referencing another Condition # Check if we are referencing another Condition
if 'Condition' in value: if "Condition" in value:
condition_values.append(condition_map[value['Condition']]) condition_values.append(condition_map[value["Condition"]])
else: else:
condition_values.append(clean_json(value, resources_map)) condition_values.append(clean_json(value, resources_map))
@ -343,36 +381,49 @@ def parse_condition(condition, resources_map, condition_map):
elif condition_operator == "Fn::Not": elif condition_operator == "Fn::Not":
return not parse_condition(condition_values[0], resources_map, condition_map) return not parse_condition(condition_values[0], resources_map, condition_map)
elif condition_operator == "Fn::And": elif condition_operator == "Fn::And":
return all([ return all(
parse_condition(condition_value, resources_map, condition_map) [
for condition_value parse_condition(condition_value, resources_map, condition_map)
in condition_values]) for condition_value in condition_values
]
)
elif condition_operator == "Fn::Or": elif condition_operator == "Fn::Or":
return any([ return any(
parse_condition(condition_value, resources_map, condition_map) [
for condition_value parse_condition(condition_value, resources_map, condition_map)
in condition_values]) for condition_value in condition_values
]
)
def parse_output(output_logical_id, output_json, resources_map): def parse_output(output_logical_id, output_json, resources_map):
output_json = clean_json(output_json, resources_map) output_json = clean_json(output_json, resources_map)
output = Output() output = Output()
output.key = output_logical_id output.key = output_logical_id
output.value = clean_json(output_json['Value'], resources_map) output.value = clean_json(output_json["Value"], resources_map)
output.description = output_json.get('Description') output.description = output_json.get("Description")
return output return output
class ResourceMap(collections.Mapping): class ResourceMap(collections_abc.Mapping):
""" """
This is a lazy loading map for resources. This allows us to create resources This is a lazy loading map for resources. This allows us to create resources
without needing to create a full dependency tree. Upon creation, each without needing to create a full dependency tree. Upon creation, each
each resources is passed this lazy map that it can grab dependencies from. each resources is passed this lazy map that it can grab dependencies from.
""" """
def __init__(self, stack_id, stack_name, parameters, tags, region_name, template, cross_stack_resources): def __init__(
self,
stack_id,
stack_name,
parameters,
tags,
region_name,
template,
cross_stack_resources,
):
self._template = template self._template = template
self._resource_json_map = template['Resources'] self._resource_json_map = template["Resources"]
self._region_name = region_name self._region_name = region_name
self.input_parameters = parameters self.input_parameters = parameters
self.tags = copy.deepcopy(tags) self.tags = copy.deepcopy(tags)
@ -381,7 +432,7 @@ class ResourceMap(collections.Mapping):
# Create the default resources # Create the default resources
self._parsed_resources = { self._parsed_resources = {
"AWS::AccountId": "123456789012", "AWS::AccountId": ACCOUNT_ID,
"AWS::Region": self._region_name, "AWS::Region": self._region_name,
"AWS::StackId": stack_id, "AWS::StackId": stack_id,
"AWS::StackName": stack_name, "AWS::StackName": stack_name,
@ -400,7 +451,8 @@ class ResourceMap(collections.Mapping):
if not resource_json: if not resource_json:
raise KeyError(resource_logical_id) raise KeyError(resource_logical_id)
new_resource = parse_and_create_resource( new_resource = parse_and_create_resource(
resource_logical_id, resource_json, self, self._region_name) resource_logical_id, resource_json, self, self._region_name
)
if new_resource is not None: if new_resource is not None:
self._parsed_resources[resource_logical_id] = new_resource self._parsed_resources[resource_logical_id] = new_resource
return new_resource return new_resource
@ -416,20 +468,27 @@ class ResourceMap(collections.Mapping):
return self._resource_json_map.keys() return self._resource_json_map.keys()
def load_mapping(self): def load_mapping(self):
self._parsed_resources.update(self._template.get('Mappings', {})) self._parsed_resources.update(self._template.get("Mappings", {}))
def load_parameters(self): def load_parameters(self):
parameter_slots = self._template.get('Parameters', {}) parameter_slots = self._template.get("Parameters", {})
for parameter_name, parameter in parameter_slots.items(): for parameter_name, parameter in parameter_slots.items():
# Set the default values. # Set the default values.
self.resolved_parameters[parameter_name] = parameter.get('Default') self.resolved_parameters[parameter_name] = parameter.get("Default")
# Set any input parameters that were passed # Set any input parameters that were passed
self.no_echo_parameter_keys = []
for key, value in self.input_parameters.items(): for key, value in self.input_parameters.items():
if key in self.resolved_parameters: if key in self.resolved_parameters:
value_type = parameter_slots[key].get('Type', 'String') parameter_slot = parameter_slots[key]
if value_type == 'CommaDelimitedList' or value_type.startswith("List"):
value = value.split(',') value_type = parameter_slot.get("Type", "String")
if value_type == "CommaDelimitedList" or value_type.startswith("List"):
value = value.split(",")
if parameter_slot.get("NoEcho"):
self.no_echo_parameter_keys.append(key)
self.resolved_parameters[key] = value self.resolved_parameters[key] = value
# Check if there are any non-default params that were not passed input # Check if there are any non-default params that were not passed input
@ -441,11 +500,15 @@ class ResourceMap(collections.Mapping):
self._parsed_resources.update(self.resolved_parameters) self._parsed_resources.update(self.resolved_parameters)
def load_conditions(self): def load_conditions(self):
conditions = self._template.get('Conditions', {}) conditions = self._template.get("Conditions", {})
self.lazy_condition_map = LazyDict() self.lazy_condition_map = LazyDict()
for condition_name, condition in conditions.items(): for condition_name, condition in conditions.items():
self.lazy_condition_map[condition_name] = functools.partial(parse_condition, self.lazy_condition_map[condition_name] = functools.partial(
condition, self._parsed_resources, self.lazy_condition_map) parse_condition,
condition,
self._parsed_resources,
self.lazy_condition_map,
)
for condition_name in self.lazy_condition_map: for condition_name in self.lazy_condition_map:
self.lazy_condition_map[condition_name] self.lazy_condition_map[condition_name]
@ -457,13 +520,18 @@ class ResourceMap(collections.Mapping):
# Since this is a lazy map, to create every object we just need to # Since this is a lazy map, to create every object we just need to
# iterate through self. # iterate through self.
self.tags.update({'aws:cloudformation:stack-name': self.get('AWS::StackName'), self.tags.update(
'aws:cloudformation:stack-id': self.get('AWS::StackId')}) {
"aws:cloudformation:stack-name": self.get("AWS::StackName"),
"aws:cloudformation:stack-id": self.get("AWS::StackId"),
}
)
for resource in self.resources: for resource in self.resources:
if isinstance(self[resource], ec2_models.TaggedEC2Resource): if isinstance(self[resource], ec2_models.TaggedEC2Resource):
self.tags['aws:cloudformation:logical-id'] = resource self.tags["aws:cloudformation:logical-id"] = resource
ec2_models.ec2_backends[self._region_name].create_tags( ec2_models.ec2_backends[self._region_name].create_tags(
[self[resource].physical_resource_id], self.tags) [self[resource].physical_resource_id], self.tags
)
def diff(self, template, parameters=None): def diff(self, template, parameters=None):
if parameters: if parameters:
@ -473,36 +541,35 @@ class ResourceMap(collections.Mapping):
self.load_conditions() self.load_conditions()
old_template = self._resource_json_map old_template = self._resource_json_map
new_template = template['Resources'] new_template = template["Resources"]
resource_names_by_action = { resource_names_by_action = {
'Add': set(new_template) - set(old_template), "Add": set(new_template) - set(old_template),
'Modify': set(name for name in new_template if name in old_template and new_template[ "Modify": set(
name] != old_template[name]), name
'Remove': set(old_template) - set(new_template) for name in new_template
} if name in old_template and new_template[name] != old_template[name]
resources_by_action = { ),
'Add': {}, "Remove": set(old_template) - set(new_template),
'Modify': {},
'Remove': {},
} }
resources_by_action = {"Add": {}, "Modify": {}, "Remove": {}}
for resource_name in resource_names_by_action['Add']: for resource_name in resource_names_by_action["Add"]:
resources_by_action['Add'][resource_name] = { resources_by_action["Add"][resource_name] = {
'LogicalResourceId': resource_name, "LogicalResourceId": resource_name,
'ResourceType': new_template[resource_name]['Type'] "ResourceType": new_template[resource_name]["Type"],
} }
for resource_name in resource_names_by_action['Modify']: for resource_name in resource_names_by_action["Modify"]:
resources_by_action['Modify'][resource_name] = { resources_by_action["Modify"][resource_name] = {
'LogicalResourceId': resource_name, "LogicalResourceId": resource_name,
'ResourceType': new_template[resource_name]['Type'] "ResourceType": new_template[resource_name]["Type"],
} }
for resource_name in resource_names_by_action['Remove']: for resource_name in resource_names_by_action["Remove"]:
resources_by_action['Remove'][resource_name] = { resources_by_action["Remove"][resource_name] = {
'LogicalResourceId': resource_name, "LogicalResourceId": resource_name,
'ResourceType': old_template[resource_name]['Type'] "ResourceType": old_template[resource_name]["Type"],
} }
return resources_by_action return resources_by_action
@ -511,35 +578,38 @@ class ResourceMap(collections.Mapping):
resources_by_action = self.diff(template, parameters) resources_by_action = self.diff(template, parameters)
old_template = self._resource_json_map old_template = self._resource_json_map
new_template = template['Resources'] new_template = template["Resources"]
self._resource_json_map = new_template self._resource_json_map = new_template
for resource_name, resource in resources_by_action['Add'].items(): for resource_name, resource in resources_by_action["Add"].items():
resource_json = new_template[resource_name] resource_json = new_template[resource_name]
new_resource = parse_and_create_resource( new_resource = parse_and_create_resource(
resource_name, resource_json, self, self._region_name) resource_name, resource_json, self, self._region_name
)
self._parsed_resources[resource_name] = new_resource self._parsed_resources[resource_name] = new_resource
for resource_name, resource in resources_by_action['Remove'].items(): for resource_name, resource in resources_by_action["Remove"].items():
resource_json = old_template[resource_name] resource_json = old_template[resource_name]
parse_and_delete_resource( parse_and_delete_resource(
resource_name, resource_json, self, self._region_name) resource_name, resource_json, self, self._region_name
)
self._parsed_resources.pop(resource_name) self._parsed_resources.pop(resource_name)
tries = 1 tries = 1
while resources_by_action['Modify'] and tries < 5: while resources_by_action["Modify"] and tries < 5:
for resource_name, resource in resources_by_action['Modify'].copy().items(): for resource_name, resource in resources_by_action["Modify"].copy().items():
resource_json = new_template[resource_name] resource_json = new_template[resource_name]
try: try:
changed_resource = parse_and_update_resource( changed_resource = parse_and_update_resource(
resource_name, resource_json, self, self._region_name) resource_name, resource_json, self, self._region_name
)
except Exception as e: except Exception as e:
# skip over dependency violations, and try again in a # skip over dependency violations, and try again in a
# second pass # second pass
last_exception = e last_exception = e
else: else:
self._parsed_resources[resource_name] = changed_resource self._parsed_resources[resource_name] = changed_resource
del resources_by_action['Modify'][resource_name] del resources_by_action["Modify"][resource_name]
tries += 1 tries += 1
if tries == 5: if tries == 5:
raise last_exception raise last_exception
@ -551,7 +621,7 @@ class ResourceMap(collections.Mapping):
for resource in remaining_resources.copy(): for resource in remaining_resources.copy():
parsed_resource = self._parsed_resources.get(resource) parsed_resource = self._parsed_resources.get(resource)
try: try:
if parsed_resource and hasattr(parsed_resource, 'delete'): if parsed_resource and hasattr(parsed_resource, "delete"):
parsed_resource.delete(self._region_name) parsed_resource.delete(self._region_name)
except Exception as e: except Exception as e:
# skip over dependency violations, and try again in a # skip over dependency violations, and try again in a
@ -564,12 +634,11 @@ class ResourceMap(collections.Mapping):
raise last_exception raise last_exception
class OutputMap(collections.Mapping): class OutputMap(collections_abc.Mapping):
def __init__(self, resources, template, stack_id): def __init__(self, resources, template, stack_id):
self._template = template self._template = template
self._stack_id = stack_id self._stack_id = stack_id
self._output_json_map = template.get('Outputs') self._output_json_map = template.get("Outputs")
# Create the default resources # Create the default resources
self._resource_map = resources self._resource_map = resources
@ -583,7 +652,8 @@ class OutputMap(collections.Mapping):
else: else:
output_json = self._output_json_map.get(output_logical_id) output_json = self._output_json_map.get(output_logical_id)
new_output = parse_output( new_output = parse_output(
output_logical_id, output_json, self._resource_map) output_logical_id, output_json, self._resource_map
)
self._parsed_outputs[output_logical_id] = new_output self._parsed_outputs[output_logical_id] = new_output
return new_output return new_output
@ -602,9 +672,11 @@ class OutputMap(collections.Mapping):
exports = [] exports = []
if self.outputs: if self.outputs:
for key, value in self._output_json_map.items(): for key, value in self._output_json_map.items():
if value.get('Export'): if value.get("Export"):
cleaned_name = clean_json(value['Export'].get('Name'), self._resource_map) cleaned_name = clean_json(
cleaned_value = clean_json(value.get('Value'), self._resource_map) value["Export"].get("Name"), self._resource_map
)
cleaned_value = clean_json(value.get("Value"), self._resource_map)
exports.append(Export(self._stack_id, cleaned_name, cleaned_value)) exports.append(Export(self._stack_id, cleaned_name, cleaned_value))
return exports return exports
@ -614,7 +686,6 @@ class OutputMap(collections.Mapping):
class Export(object): class Export(object):
def __init__(self, exporting_stack_id, name, value): def __init__(self, exporting_stack_id, name, value):
self._exporting_stack_id = exporting_stack_id self._exporting_stack_id = exporting_stack_id
self._name = name self._name = name

View File

@ -7,12 +7,12 @@ from six.moves.urllib.parse import urlparse
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id from moto.core.utils import amzn_request_id
from moto.s3 import s3_backend from moto.s3 import s3_backend
from moto.core import ACCOUNT_ID
from .models import cloudformation_backends from .models import cloudformation_backends
from .exceptions import ValidationError from .exceptions import ValidationError
class CloudFormationResponse(BaseResponse): class CloudFormationResponse(BaseResponse):
@property @property
def cloudformation_backend(self): def cloudformation_backend(self):
return cloudformation_backends[self.region] return cloudformation_backends[self.region]
@ -20,17 +20,18 @@ class CloudFormationResponse(BaseResponse):
def _get_stack_from_s3_url(self, template_url): def _get_stack_from_s3_url(self, template_url):
template_url_parts = urlparse(template_url) template_url_parts = urlparse(template_url)
if "localhost" in template_url: if "localhost" in template_url:
bucket_name, key_name = template_url_parts.path.lstrip( bucket_name, key_name = template_url_parts.path.lstrip("/").split("/", 1)
"/").split("/", 1)
else: else:
if template_url_parts.netloc.endswith('amazonaws.com') \ if template_url_parts.netloc.endswith(
and template_url_parts.netloc.startswith('s3'): "amazonaws.com"
) and template_url_parts.netloc.startswith("s3"):
# Handle when S3 url uses amazon url with bucket in path # Handle when S3 url uses amazon url with bucket in path
# Also handles getting region as technically s3 is region'd # Also handles getting region as technically s3 is region'd
# region = template_url.netloc.split('.')[1] # region = template_url.netloc.split('.')[1]
bucket_name, key_name = template_url_parts.path.lstrip( bucket_name, key_name = template_url_parts.path.lstrip("/").split(
"/").split("/", 1) "/", 1
)
else: else:
bucket_name = template_url_parts.netloc.split(".")[0] bucket_name = template_url_parts.netloc.split(".")[0]
key_name = template_url_parts.path.lstrip("/") key_name = template_url_parts.path.lstrip("/")
@ -39,24 +40,26 @@ class CloudFormationResponse(BaseResponse):
return key.value.decode("utf-8") return key.value.decode("utf-8")
def create_stack(self): def create_stack(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
stack_body = self._get_param('TemplateBody') stack_body = self._get_param("TemplateBody")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
role_arn = self._get_param('RoleARN') role_arn = self._get_param("RoleARN")
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# Hack dict-comprehension # Hack dict-comprehension
parameters = dict([ parameters = dict(
(parameter['parameter_key'], parameter['parameter_value']) [
for parameter (parameter["parameter_key"], parameter["parameter_value"])
in parameters_list for parameter in parameters_list
]) ]
)
if template_url: if template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param( stack_notification_arns = self._get_multi_param("NotificationARNs.member")
'NotificationARNs.member')
stack = self.cloudformation_backend.create_stack( stack = self.cloudformation_backend.create_stack(
name=stack_name, name=stack_name,
@ -68,34 +71,37 @@ class CloudFormationResponse(BaseResponse):
role_arn=role_arn, role_arn=role_arn,
) )
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'CreateStackResponse': { {
'CreateStackResult': { "CreateStackResponse": {
'StackId': stack.stack_id, "CreateStackResult": {"StackId": stack.stack_id}
} }
} }
}) )
else: else:
template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE) template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE)
return template.render(stack=stack) return template.render(stack=stack)
@amzn_request_id @amzn_request_id
def create_change_set(self): def create_change_set(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
change_set_name = self._get_param('ChangeSetName') change_set_name = self._get_param("ChangeSetName")
stack_body = self._get_param('TemplateBody') stack_body = self._get_param("TemplateBody")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
role_arn = self._get_param('RoleARN') role_arn = self._get_param("RoleARN")
update_or_create = self._get_param('ChangeSetType', 'CREATE') update_or_create = self._get_param("ChangeSetType", "CREATE")
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
parameters = {param['parameter_key']: param['parameter_value'] for item in self._get_list_prefix("Tags.member")
for param in parameters_list} )
parameters = {
param["parameter_key"]: param["parameter_value"]
for param in parameters_list
}
if template_url: if template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param( stack_notification_arns = self._get_multi_param("NotificationARNs.member")
'NotificationARNs.member')
change_set_id, stack_id = self.cloudformation_backend.create_change_set( change_set_id, stack_id = self.cloudformation_backend.create_change_set(
stack_name=stack_name, stack_name=stack_name,
change_set_name=change_set_name, change_set_name=change_set_name,
@ -108,66 +114,64 @@ class CloudFormationResponse(BaseResponse):
change_set_type=update_or_create, change_set_type=update_or_create,
) )
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'CreateChangeSetResponse': { {
'CreateChangeSetResult': { "CreateChangeSetResponse": {
'Id': change_set_id, "CreateChangeSetResult": {
'StackId': stack_id, "Id": change_set_id,
"StackId": stack_id,
}
} }
} }
}) )
else: else:
template = self.response_template(CREATE_CHANGE_SET_RESPONSE_TEMPLATE) template = self.response_template(CREATE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render(stack_id=stack_id, change_set_id=change_set_id) return template.render(stack_id=stack_id, change_set_id=change_set_id)
def delete_change_set(self): def delete_change_set(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
change_set_name = self._get_param('ChangeSetName') change_set_name = self._get_param("ChangeSetName")
self.cloudformation_backend.delete_change_set(change_set_name=change_set_name, stack_name=stack_name) self.cloudformation_backend.delete_change_set(
change_set_name=change_set_name, stack_name=stack_name
)
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'DeleteChangeSetResponse': { {"DeleteChangeSetResponse": {"DeleteChangeSetResult": {}}}
'DeleteChangeSetResult': {}, )
}
})
else: else:
template = self.response_template(DELETE_CHANGE_SET_RESPONSE_TEMPLATE) template = self.response_template(DELETE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render() return template.render()
def describe_change_set(self): def describe_change_set(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
change_set_name = self._get_param('ChangeSetName') change_set_name = self._get_param("ChangeSetName")
change_set = self.cloudformation_backend.describe_change_set( change_set = self.cloudformation_backend.describe_change_set(
change_set_name=change_set_name, change_set_name=change_set_name, stack_name=stack_name
stack_name=stack_name,
) )
template = self.response_template(DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE) template = self.response_template(DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render(change_set=change_set) return template.render(change_set=change_set)
@amzn_request_id @amzn_request_id
def execute_change_set(self): def execute_change_set(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
change_set_name = self._get_param('ChangeSetName') change_set_name = self._get_param("ChangeSetName")
self.cloudformation_backend.execute_change_set( self.cloudformation_backend.execute_change_set(
stack_name=stack_name, stack_name=stack_name, change_set_name=change_set_name
change_set_name=change_set_name,
) )
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'ExecuteChangeSetResponse': { {"ExecuteChangeSetResponse": {"ExecuteChangeSetResult": {}}}
'ExecuteChangeSetResult': {}, )
}
})
else: else:
template = self.response_template(EXECUTE_CHANGE_SET_RESPONSE_TEMPLATE) template = self.response_template(EXECUTE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render() return template.render()
def describe_stacks(self): def describe_stacks(self):
stack_name_or_id = None stack_name_or_id = None
if self._get_param('StackName'): if self._get_param("StackName"):
stack_name_or_id = self.querystring.get('StackName')[0] stack_name_or_id = self.querystring.get("StackName")[0]
token = self._get_param('NextToken') token = self._get_param("NextToken")
stacks = self.cloudformation_backend.describe_stacks(stack_name_or_id) stacks = self.cloudformation_backend.describe_stacks(stack_name_or_id)
stack_ids = [stack.stack_id for stack in stacks] stack_ids = [stack.stack_id for stack in stacks]
if token: if token:
@ -175,7 +179,7 @@ class CloudFormationResponse(BaseResponse):
else: else:
start = 0 start = 0
max_results = 50 # using this to mske testing of paginated stacks more convenient than default 1 MB max_results = 50 # using this to mske testing of paginated stacks more convenient than default 1 MB
stacks_resp = stacks[start:start + max_results] stacks_resp = stacks[start : start + max_results]
next_token = None next_token = None
if len(stacks) > (start + max_results): if len(stacks) > (start + max_results):
next_token = stacks_resp[-1].stack_id next_token = stacks_resp[-1].stack_id
@ -183,9 +187,9 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacks=stacks_resp, next_token=next_token) return template.render(stacks=stacks_resp, next_token=next_token)
def describe_stack_resource(self): def describe_stack_resource(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
logical_resource_id = self._get_param('LogicalResourceId') logical_resource_id = self._get_param("LogicalResourceId")
for stack_resource in stack.stack_resources: for stack_resource in stack.stack_resources:
if stack_resource.logical_resource_id == logical_resource_id: if stack_resource.logical_resource_id == logical_resource_id:
@ -194,19 +198,18 @@ class CloudFormationResponse(BaseResponse):
else: else:
raise ValidationError(logical_resource_id) raise ValidationError(logical_resource_id)
template = self.response_template( template = self.response_template(DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
return template.render(stack=stack, resource=resource) return template.render(stack=stack, resource=resource)
def describe_stack_resources(self): def describe_stack_resources(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
template = self.response_template(DESCRIBE_STACK_RESOURCES_RESPONSE) template = self.response_template(DESCRIBE_STACK_RESOURCES_RESPONSE)
return template.render(stack=stack) return template.render(stack=stack)
def describe_stack_events(self): def describe_stack_events(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
template = self.response_template(DESCRIBE_STACK_EVENTS_RESPONSE) template = self.response_template(DESCRIBE_STACK_EVENTS_RESPONSE)
@ -223,68 +226,82 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacks=stacks) return template.render(stacks=stacks)
def list_stack_resources(self): def list_stack_resources(self):
stack_name_or_id = self._get_param('StackName') stack_name_or_id = self._get_param("StackName")
resources = self.cloudformation_backend.list_stack_resources( resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id)
stack_name_or_id)
template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE) template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE)
return template.render(resources=resources) return template.render(resources=resources)
def get_template(self): def get_template(self):
name_or_stack_id = self.querystring.get('StackName')[0] name_or_stack_id = self.querystring.get("StackName")[0]
stack = self.cloudformation_backend.get_stack(name_or_stack_id) stack = self.cloudformation_backend.get_stack(name_or_stack_id)
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
"GetTemplateResponse": { {
"GetTemplateResult": { "GetTemplateResponse": {
"TemplateBody": stack.template, "GetTemplateResult": {
"ResponseMetadata": { "TemplateBody": stack.template,
"RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" "ResponseMetadata": {
"RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE"
},
} }
} }
} }
}) )
else: else:
template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE) template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE)
return template.render(stack=stack) return template.render(stack=stack)
def update_stack(self): def update_stack(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
role_arn = self._get_param('RoleARN') role_arn = self._get_param("RoleARN")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
stack_body = self._get_param('TemplateBody') stack_body = self._get_param("TemplateBody")
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
if self._get_param('UsePreviousTemplate') == "true": if self._get_param("UsePreviousTemplate") == "true":
stack_body = stack.template stack_body = stack.template
elif not stack_body and template_url: elif not stack_body and template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
incoming_params = self._get_list_prefix("Parameters.member") incoming_params = self._get_list_prefix("Parameters.member")
parameters = dict([ parameters = dict(
(parameter['parameter_key'], parameter['parameter_value']) [
for parameter (parameter["parameter_key"], parameter["parameter_value"])
in incoming_params if 'parameter_value' in parameter for parameter in incoming_params
]) if "parameter_value" in parameter
previous = dict([ ]
(parameter['parameter_key'], stack.parameters[parameter['parameter_key']]) )
for parameter previous = dict(
in incoming_params if 'use_previous_value' in parameter [
]) (
parameter["parameter_key"],
stack.parameters[parameter["parameter_key"]],
)
for parameter in incoming_params
if "use_previous_value" in parameter
]
)
parameters.update(previous) parameters.update(previous)
# boto3 is supposed to let you clear the tags by passing an empty value, but the request body doesn't # boto3 is supposed to let you clear the tags by passing an empty value, but the request body doesn't
# end up containing anything we can use to differentiate between passing an empty value versus not # end up containing anything we can use to differentiate between passing an empty value versus not
# passing anything. so until that changes, moto won't be able to clear tags, only update them. # passing anything. so until that changes, moto won't be able to clear tags, only update them.
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# so that if we don't pass the parameter, we don't clear all the tags accidentally # so that if we don't pass the parameter, we don't clear all the tags accidentally
if not tags: if not tags:
tags = None tags = None
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
if stack.status == 'ROLLBACK_COMPLETE': if stack.status == "ROLLBACK_COMPLETE":
raise ValidationError( raise ValidationError(
stack.stack_id, message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(stack.stack_id)) stack.stack_id,
message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(
stack.stack_id
),
)
stack = self.cloudformation_backend.update_stack( stack = self.cloudformation_backend.update_stack(
name=stack_name, name=stack_name,
@ -295,11 +312,7 @@ class CloudFormationResponse(BaseResponse):
) )
if self.request_json: if self.request_json:
stack_body = { stack_body = {
'UpdateStackResponse': { "UpdateStackResponse": {"UpdateStackResult": {"StackId": stack.name}}
'UpdateStackResult': {
'StackId': stack.name,
}
}
} }
return json.dumps(stack_body) return json.dumps(stack_body)
else: else:
@ -307,56 +320,57 @@ class CloudFormationResponse(BaseResponse):
return template.render(stack=stack) return template.render(stack=stack)
def delete_stack(self): def delete_stack(self):
name_or_stack_id = self.querystring.get('StackName')[0] name_or_stack_id = self.querystring.get("StackName")[0]
self.cloudformation_backend.delete_stack(name_or_stack_id) self.cloudformation_backend.delete_stack(name_or_stack_id)
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps({"DeleteStackResponse": {"DeleteStackResult": {}}})
'DeleteStackResponse': {
'DeleteStackResult': {},
}
})
else: else:
template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE) template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE)
return template.render() return template.render()
def list_exports(self): def list_exports(self):
token = self._get_param('NextToken') token = self._get_param("NextToken")
exports, next_token = self.cloudformation_backend.list_exports(token=token) exports, next_token = self.cloudformation_backend.list_exports(token=token)
template = self.response_template(LIST_EXPORTS_RESPONSE) template = self.response_template(LIST_EXPORTS_RESPONSE)
return template.render(exports=exports, next_token=next_token) return template.render(exports=exports, next_token=next_token)
def validate_template(self): def validate_template(self):
cfn_lint = self.cloudformation_backend.validate_template(self._get_param('TemplateBody')) cfn_lint = self.cloudformation_backend.validate_template(
self._get_param("TemplateBody")
)
if cfn_lint: if cfn_lint:
raise ValidationError(cfn_lint[0].message) raise ValidationError(cfn_lint[0].message)
description = "" description = ""
try: try:
description = json.loads(self._get_param('TemplateBody'))['Description'] description = json.loads(self._get_param("TemplateBody"))["Description"]
except (ValueError, KeyError): except (ValueError, KeyError):
pass pass
try: try:
description = yaml.load(self._get_param('TemplateBody'))['Description'] description = yaml.load(self._get_param("TemplateBody"))["Description"]
except (yaml.ParserError, KeyError): except (yaml.ParserError, KeyError):
pass pass
template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE) template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE)
return template.render(description=description) return template.render(description=description)
def create_stack_set(self): def create_stack_set(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
stack_body = self._get_param('TemplateBody') stack_body = self._get_param("TemplateBody")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
# role_arn = self._get_param('RoleARN') # role_arn = self._get_param('RoleARN')
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# Copy-Pasta - Hack dict-comprehension # Copy-Pasta - Hack dict-comprehension
parameters = dict([ parameters = dict(
(parameter['parameter_key'], parameter['parameter_value']) [
for parameter (parameter["parameter_key"], parameter["parameter_value"])
in parameters_list for parameter in parameters_list
]) ]
)
if template_url: if template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
@ -368,59 +382,67 @@ class CloudFormationResponse(BaseResponse):
# role_arn=role_arn, # role_arn=role_arn,
) )
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'CreateStackSetResponse': { {
'CreateStackSetResult': { "CreateStackSetResponse": {
'StackSetId': stackset.stackset_id, "CreateStackSetResult": {"StackSetId": stackset.stackset_id}
} }
} }
}) )
else: else:
template = self.response_template(CREATE_STACK_SET_RESPONSE_TEMPLATE) template = self.response_template(CREATE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(stackset=stackset) return template.render(stackset=stackset)
def create_stack_instances(self): def create_stack_instances(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param('Accounts.member') accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param('Regions.member') regions = self._get_multi_param("Regions.member")
parameters = self._get_multi_param('ParameterOverrides.member') parameters = self._get_multi_param("ParameterOverrides.member")
self.cloudformation_backend.create_stack_instances(stackset_name, accounts, regions, parameters) self.cloudformation_backend.create_stack_instances(
stackset_name, accounts, regions, parameters
)
template = self.response_template(CREATE_STACK_INSTANCES_TEMPLATE) template = self.response_template(CREATE_STACK_INSTANCES_TEMPLATE)
return template.render() return template.render()
def delete_stack_set(self): def delete_stack_set(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
self.cloudformation_backend.delete_stack_set(stackset_name) self.cloudformation_backend.delete_stack_set(stackset_name)
template = self.response_template(DELETE_STACK_SET_RESPONSE_TEMPLATE) template = self.response_template(DELETE_STACK_SET_RESPONSE_TEMPLATE)
return template.render() return template.render()
def delete_stack_instances(self): def delete_stack_instances(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param('Accounts.member') accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param('Regions.member') regions = self._get_multi_param("Regions.member")
operation = self.cloudformation_backend.delete_stack_instances(stackset_name, accounts, regions) operation = self.cloudformation_backend.delete_stack_instances(
stackset_name, accounts, regions
)
template = self.response_template(DELETE_STACK_INSTANCES_TEMPLATE) template = self.response_template(DELETE_STACK_INSTANCES_TEMPLATE)
return template.render(operation=operation) return template.render(operation=operation)
def describe_stack_set(self): def describe_stack_set(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
if not stackset.admin_role: if not stackset.admin_role:
stackset.admin_role = 'arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole' stackset.admin_role = "arn:aws:iam::{AccountId}:role/AWSCloudFormationStackSetAdministrationRole".format(
AccountId=ACCOUNT_ID
)
if not stackset.execution_role: if not stackset.execution_role:
stackset.execution_role = 'AWSCloudFormationStackSetExecutionRole' stackset.execution_role = "AWSCloudFormationStackSetExecutionRole"
template = self.response_template(DESCRIBE_STACK_SET_RESPONSE_TEMPLATE) template = self.response_template(DESCRIBE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(stackset=stackset) return template.render(stackset=stackset)
def describe_stack_instance(self): def describe_stack_instance(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
account = self._get_param('StackInstanceAccount') account = self._get_param("StackInstanceAccount")
region = self._get_param('StackInstanceRegion') region = self._get_param("StackInstanceRegion")
instance = self.cloudformation_backend.get_stack_set(stackset_name).instances.get_instance(account, region) instance = self.cloudformation_backend.get_stack_set(
stackset_name
).instances.get_instance(account, region)
template = self.response_template(DESCRIBE_STACK_INSTANCE_TEMPLATE) template = self.response_template(DESCRIBE_STACK_INSTANCE_TEMPLATE)
rendered = template.render(instance=instance) rendered = template.render(instance=instance)
return rendered return rendered
@ -431,61 +453,66 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacksets=stacksets) return template.render(stacksets=stacksets)
def list_stack_instances(self): def list_stack_instances(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
template = self.response_template(LIST_STACK_INSTANCES_TEMPLATE) template = self.response_template(LIST_STACK_INSTANCES_TEMPLATE)
return template.render(stackset=stackset) return template.render(stackset=stackset)
def list_stack_set_operations(self): def list_stack_set_operations(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
template = self.response_template(LIST_STACK_SET_OPERATIONS_RESPONSE_TEMPLATE) template = self.response_template(LIST_STACK_SET_OPERATIONS_RESPONSE_TEMPLATE)
return template.render(stackset=stackset) return template.render(stackset=stackset)
def stop_stack_set_operation(self): def stop_stack_set_operation(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
operation_id = self._get_param('OperationId') operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
stackset.update_operation(operation_id, 'STOPPED') stackset.update_operation(operation_id, "STOPPED")
template = self.response_template(STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE) template = self.response_template(STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE)
return template.render() return template.render()
def describe_stack_set_operation(self): def describe_stack_set_operation(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
operation_id = self._get_param('OperationId') operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
operation = stackset.get_operation(operation_id) operation = stackset.get_operation(operation_id)
template = self.response_template(DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE) template = self.response_template(DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE)
return template.render(stackset=stackset, operation=operation) return template.render(stackset=stackset, operation=operation)
def list_stack_set_operation_results(self): def list_stack_set_operation_results(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
operation_id = self._get_param('OperationId') operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
operation = stackset.get_operation(operation_id) operation = stackset.get_operation(operation_id)
template = self.response_template(LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE) template = self.response_template(
LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE
)
return template.render(operation=operation) return template.render(operation=operation)
def update_stack_set(self): def update_stack_set(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
operation_id = self._get_param('OperationId') operation_id = self._get_param("OperationId")
description = self._get_param('Description') description = self._get_param("Description")
execution_role = self._get_param('ExecutionRoleName') execution_role = self._get_param("ExecutionRoleName")
admin_role = self._get_param('AdministrationRoleARN') admin_role = self._get_param("AdministrationRoleARN")
accounts = self._get_multi_param('Accounts.member') accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param('Regions.member') regions = self._get_multi_param("Regions.member")
template_body = self._get_param('TemplateBody') template_body = self._get_param("TemplateBody")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
if template_url: if template_url:
template_body = self._get_stack_from_s3_url(template_url) template_body = self._get_stack_from_s3_url(template_url)
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
parameters = dict([ parameters = dict(
(parameter['parameter_key'], parameter['parameter_value']) [
for parameter (parameter["parameter_key"], parameter["parameter_value"])
in parameters_list for parameter in parameters_list
]) ]
)
operation = self.cloudformation_backend.update_stack_set( operation = self.cloudformation_backend.update_stack_set(
stackset_name=stackset_name, stackset_name=stackset_name,
template=template_body, template=template_body,
@ -496,18 +523,20 @@ class CloudFormationResponse(BaseResponse):
execution_role=execution_role, execution_role=execution_role,
accounts=accounts, accounts=accounts,
regions=regions, regions=regions,
operation_id=operation_id operation_id=operation_id,
) )
template = self.response_template(UPDATE_STACK_SET_RESPONSE_TEMPLATE) template = self.response_template(UPDATE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(operation=operation) return template.render(operation=operation)
def update_stack_instances(self): def update_stack_instances(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param('Accounts.member') accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param('Regions.member') regions = self._get_multi_param("Regions.member")
parameters = self._get_multi_param('ParameterOverrides.member') parameters = self._get_multi_param("ParameterOverrides.member")
operation = self.cloudformation_backend.get_stack_set(stackset_name).update_instances(accounts, regions, parameters) operation = self.cloudformation_backend.get_stack_set(
stackset_name
).update_instances(accounts, regions, parameters)
template = self.response_template(UPDATE_STACK_INSTANCES_RESPONSE_TEMPLATE) template = self.response_template(UPDATE_STACK_INSTANCES_RESPONSE_TEMPLATE)
return template.render(operation=operation) return template.render(operation=operation)
@ -654,7 +683,11 @@ DESCRIBE_STACKS_TEMPLATE = """<DescribeStacksResponse>
{% for param_name, param_value in stack.stack_parameters.items() %} {% for param_name, param_value in stack.stack_parameters.items() %}
<member> <member>
<ParameterKey>{{ param_name }}</ParameterKey> <ParameterKey>{{ param_name }}</ParameterKey>
<ParameterValue>{{ param_value }}</ParameterValue> {% if param_name in stack.resource_map.no_echo_parameter_keys %}
<ParameterValue>****</ParameterValue>
{% else %}
<ParameterValue>{{ param_value }}</ParameterValue>
{% endif %}
</member> </member>
{% endfor %} {% endfor %}
</Parameters> </Parameters>
@ -1021,11 +1054,14 @@ STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE = """<StopStackSetOperationResponse x
</ResponseMetadata> </StopStackSetOperationResponse> </ResponseMetadata> </StopStackSetOperationResponse>
""" """
DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = """<DescribeStackSetOperationResponse xmlns="http://internal.amazon.com/coral/com.amazonaws.maestro.service.v20160713/"> DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = (
"""<DescribeStackSetOperationResponse xmlns="http://internal.amazon.com/coral/com.amazonaws.maestro.service.v20160713/">
<DescribeStackSetOperationResult> <DescribeStackSetOperationResult>
<StackSetOperation> <StackSetOperation>
<ExecutionRoleName>{{ stackset.execution_role }}</ExecutionRoleName> <ExecutionRoleName>{{ stackset.execution_role }}</ExecutionRoleName>
<AdministrationRoleARN>arn:aws:iam::123456789012:role/{{ stackset.admin_role }}</AdministrationRoleARN> <AdministrationRoleARN>arn:aws:iam::"""
+ ACCOUNT_ID
+ """:role/{{ stackset.admin_role }}</AdministrationRoleARN>
<StackSetId>{{ stackset.id }}</StackSetId> <StackSetId>{{ stackset.id }}</StackSetId>
<CreationTimestamp>{{ operation.CreationTimestamp }}</CreationTimestamp> <CreationTimestamp>{{ operation.CreationTimestamp }}</CreationTimestamp>
<OperationId>{{ operation.OperationId }}</OperationId> <OperationId>{{ operation.OperationId }}</OperationId>
@ -1042,15 +1078,19 @@ DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = """<DescribeStackSetOperationRes
</ResponseMetadata> </ResponseMetadata>
</DescribeStackSetOperationResponse> </DescribeStackSetOperationResponse>
""" """
)
LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = """<ListStackSetOperationResultsResponse xmlns="http://internal.amazon.com/coral/com.amazonaws.maestro.service.v20160713/"> LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = (
"""<ListStackSetOperationResultsResponse xmlns="http://internal.amazon.com/coral/com.amazonaws.maestro.service.v20160713/">
<ListStackSetOperationResultsResult> <ListStackSetOperationResultsResult>
<Summaries> <Summaries>
{% for instance in operation.Instances %} {% for instance in operation.Instances %}
{% for account, region in instance.items() %} {% for account, region in instance.items() %}
<member> <member>
<AccountGateResult> <AccountGateResult>
<StatusReason>Function not found: arn:aws:lambda:us-west-2:123456789012:function:AWSCloudFormationStackSetAccountGate</StatusReason> <StatusReason>Function not found: arn:aws:lambda:us-west-2:"""
+ ACCOUNT_ID
+ """:function:AWSCloudFormationStackSetAccountGate</StatusReason>
<Status>SKIPPED</Status> <Status>SKIPPED</Status>
</AccountGateResult> </AccountGateResult>
<Region>{{ region }}</Region> <Region>{{ region }}</Region>
@ -1066,3 +1106,4 @@ LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = """<ListStackSetOperationRe
</ResponseMetadata> </ResponseMetadata>
</ListStackSetOperationResultsResponse> </ListStackSetOperationResultsResponse>
""" """
)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import CloudFormationResponse from .responses import CloudFormationResponse
url_bases = [ url_bases = ["https?://cloudformation.(.+).amazonaws.com"]
"https?://cloudformation.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": CloudFormationResponse.dispatch}
'{0}/$': CloudFormationResponse.dispatch,
}

View File

@ -4,50 +4,59 @@ import six
import random import random
import yaml import yaml
import os import os
import string
from cfnlint import decode, core from cfnlint import decode, core
from moto.core import ACCOUNT_ID
def generate_stack_id(stack_name, region="us-east-1", account="123456789"): def generate_stack_id(stack_name, region="us-east-1", account="123456789"):
random_id = uuid.uuid4() random_id = uuid.uuid4()
return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(region, account, stack_name, random_id) return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(
region, account, stack_name, random_id
)
def generate_changeset_id(changeset_name, region_name): def generate_changeset_id(changeset_name, region_name):
random_id = uuid.uuid4() random_id = uuid.uuid4()
return 'arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}'.format(region_name, changeset_name, random_id) return "arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}".format(
region_name, changeset_name, random_id
)
def generate_stackset_id(stackset_name): def generate_stackset_id(stackset_name):
random_id = uuid.uuid4() random_id = uuid.uuid4()
return '{}:{}'.format(stackset_name, random_id) return "{}:{}".format(stackset_name, random_id)
def generate_stackset_arn(stackset_id, region_name): def generate_stackset_arn(stackset_id, region_name):
return 'arn:aws:cloudformation:{}:123456789012:stackset/{}'.format(region_name, stackset_id) return "arn:aws:cloudformation:{}:{}:stackset/{}".format(
region_name, ACCOUNT_ID, stackset_id
)
def random_suffix(): def random_suffix():
size = 12 size = 12
chars = list(range(10)) + ['A-Z'] chars = list(range(10)) + list(string.ascii_uppercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size)) return "".join(six.text_type(random.choice(chars)) for x in range(size))
def yaml_tag_constructor(loader, tag, node): def yaml_tag_constructor(loader, tag, node):
"""convert shorthand intrinsic function to full name """convert shorthand intrinsic function to full name
""" """
def _f(loader, tag, node): def _f(loader, tag, node):
if tag == '!GetAtt': if tag == "!GetAtt":
return node.value.split('.') return node.value.split(".")
elif type(node) == yaml.SequenceNode: elif type(node) == yaml.SequenceNode:
return loader.construct_sequence(node) return loader.construct_sequence(node)
else: else:
return node.value return node.value
if tag == '!Ref': if tag == "!Ref":
key = 'Ref' key = "Ref"
else: else:
key = 'Fn::{}'.format(tag[1:]) key = "Fn::{}".format(tag[1:])
return {key: _f(loader, tag, node)} return {key: _f(loader, tag, node)}
@ -70,13 +79,9 @@ def validate_template_cfn_lint(template):
rules = core.get_rules([], [], []) rules = core.get_rules([], [], [])
# Use us-east-1 region (spec file) for validation # Use us-east-1 region (spec file) for validation
regions = ['us-east-1'] regions = ["us-east-1"]
# Process all the rules and gather the errors # Process all the rules and gather the errors
matches = core.run_checks( matches = core.run_checks(abs_filename, template, rules, regions)
abs_filename,
template,
rules,
regions)
return matches return matches

View File

@ -1,6 +1,6 @@
from .models import cloudwatch_backends from .models import cloudwatch_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
cloudwatch_backend = cloudwatch_backends['us-east-1'] cloudwatch_backend = cloudwatch_backends["us-east-1"]
mock_cloudwatch = base_decorator(cloudwatch_backends) mock_cloudwatch = base_decorator(cloudwatch_backends)
mock_cloudwatch_deprecated = deprecated_base_decorator(cloudwatch_backends) mock_cloudwatch_deprecated = deprecated_base_decorator(cloudwatch_backends)

View File

@ -1,20 +1,21 @@
import json import json
from boto3 import Session
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
import boto.ec2.cloudwatch
from datetime import datetime, timedelta from datetime import datetime, timedelta
from dateutil.tz import tzutc from dateutil.tz import tzutc
from uuid import uuid4
from .utils import make_arn_for_dashboard from .utils import make_arn_for_dashboard
DEFAULT_ACCOUNT_ID = 123456789012 from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
_EMPTY_LIST = tuple() _EMPTY_LIST = tuple()
class Dimension(object): class Dimension(object):
def __init__(self, name, value): def __init__(self, name, value):
self.name = name self.name = name
self.value = value self.value = value
@ -49,10 +50,23 @@ def daterange(start, stop, step=timedelta(days=1), inclusive=False):
class FakeAlarm(BaseModel): class FakeAlarm(BaseModel):
def __init__(
def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods, self,
period, threshold, statistic, description, dimensions, alarm_actions, name,
ok_actions, insufficient_data_actions, unit): namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
):
self.name = name self.name = name
self.namespace = namespace self.namespace = namespace
self.metric_name = metric_name self.metric_name = metric_name
@ -62,8 +76,9 @@ class FakeAlarm(BaseModel):
self.threshold = threshold self.threshold = threshold
self.statistic = statistic self.statistic = statistic
self.description = description self.description = description
self.dimensions = [Dimension(dimension['name'], dimension[ self.dimensions = [
'value']) for dimension in dimensions] Dimension(dimension["name"], dimension["value"]) for dimension in dimensions
]
self.alarm_actions = alarm_actions self.alarm_actions = alarm_actions
self.ok_actions = ok_actions self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions self.insufficient_data_actions = insufficient_data_actions
@ -72,15 +87,21 @@ class FakeAlarm(BaseModel):
self.history = [] self.history = []
self.state_reason = '' self.state_reason = ""
self.state_reason_data = '{}' self.state_reason_data = "{}"
self.state_value = 'OK' self.state_value = "OK"
self.state_updated_timestamp = datetime.utcnow() self.state_updated_timestamp = datetime.utcnow()
def update_state(self, reason, reason_data, state_value): def update_state(self, reason, reason_data, state_value):
# History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action # History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action
self.history.append( self.history.append(
('StateUpdate', self.state_reason, self.state_reason_data, self.state_value, self.state_updated_timestamp) (
"StateUpdate",
self.state_reason,
self.state_reason_data,
self.state_value,
self.state_updated_timestamp,
)
) )
self.state_reason = reason self.state_reason = reason
@ -90,14 +111,14 @@ class FakeAlarm(BaseModel):
class MetricDatum(BaseModel): class MetricDatum(BaseModel):
def __init__(self, namespace, name, value, dimensions, timestamp): def __init__(self, namespace, name, value, dimensions, timestamp):
self.namespace = namespace self.namespace = namespace
self.name = name self.name = name
self.value = value self.value = value
self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc()) self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc())
self.dimensions = [Dimension(dimension['Name'], dimension[ self.dimensions = [
'Value']) for dimension in dimensions] Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions
]
class Dashboard(BaseModel): class Dashboard(BaseModel):
@ -120,7 +141,7 @@ class Dashboard(BaseModel):
return len(self.body) return len(self.body)
def __repr__(self): def __repr__(self):
return '<CloudWatchDashboard {0}>'.format(self.name) return "<CloudWatchDashboard {0}>".format(self.name)
class Statistics: class Statistics:
@ -131,7 +152,7 @@ class Statistics:
@property @property
def sample_count(self): def sample_count(self):
if 'SampleCount' not in self.stats: if "SampleCount" not in self.stats:
return None return None
return len(self.values) return len(self.values)
@ -142,28 +163,28 @@ class Statistics:
@property @property
def sum(self): def sum(self):
if 'Sum' not in self.stats: if "Sum" not in self.stats:
return None return None
return sum(self.values) return sum(self.values)
@property @property
def minimum(self): def minimum(self):
if 'Minimum' not in self.stats: if "Minimum" not in self.stats:
return None return None
return min(self.values) return min(self.values)
@property @property
def maximum(self): def maximum(self):
if 'Maximum' not in self.stats: if "Maximum" not in self.stats:
return None return None
return max(self.values) return max(self.values)
@property @property
def average(self): def average(self):
if 'Average' not in self.stats: if "Average" not in self.stats:
return None return None
# when moto is 3.4+ we can switch to the statistics module # when moto is 3.4+ we can switch to the statistics module
@ -171,18 +192,45 @@ class Statistics:
class CloudWatchBackend(BaseBackend): class CloudWatchBackend(BaseBackend):
def __init__(self): def __init__(self):
self.alarms = {} self.alarms = {}
self.dashboards = {} self.dashboards = {}
self.metric_data = [] self.metric_data = []
self.paged_metric_data = {}
def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods, def put_metric_alarm(
period, threshold, statistic, description, dimensions, self,
alarm_actions, ok_actions, insufficient_data_actions, unit): name,
alarm = FakeAlarm(name, namespace, metric_name, comparison_operator, evaluation_periods, period, namespace,
threshold, statistic, description, dimensions, alarm_actions, metric_name,
ok_actions, insufficient_data_actions, unit) comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
):
alarm = FakeAlarm(
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
)
self.alarms[name] = alarm self.alarms[name] = alarm
return alarm return alarm
@ -214,14 +262,12 @@ class CloudWatchBackend(BaseBackend):
] ]
def get_alarms_by_alarm_names(self, alarm_names): def get_alarms_by_alarm_names(self, alarm_names):
return [ return [alarm for alarm in self.alarms.values() if alarm.name in alarm_names]
alarm
for alarm in self.alarms.values()
if alarm.name in alarm_names
]
def get_alarms_by_state_value(self, target_state): def get_alarms_by_state_value(self, target_state):
return filter(lambda alarm: alarm.state_value == target_state, self.alarms.values()) return filter(
lambda alarm: alarm.state_value == target_state, self.alarms.values()
)
def delete_alarms(self, alarm_names): def delete_alarms(self, alarm_names):
for alarm_name in alarm_names: for alarm_name in alarm_names:
@ -230,17 +276,31 @@ class CloudWatchBackend(BaseBackend):
def put_metric_data(self, namespace, metric_data): def put_metric_data(self, namespace, metric_data):
for metric_member in metric_data: for metric_member in metric_data:
# Preserve "datetime" for get_metric_statistics comparisons # Preserve "datetime" for get_metric_statistics comparisons
timestamp = metric_member.get('Timestamp') timestamp = metric_member.get("Timestamp")
if timestamp is not None and type(timestamp) != datetime: if timestamp is not None and type(timestamp) != datetime:
timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
timestamp = timestamp.replace(tzinfo=tzutc()) timestamp = timestamp.replace(tzinfo=tzutc())
self.metric_data.append(MetricDatum( self.metric_data.append(
namespace, metric_member['MetricName'], float(metric_member.get('Value', 0)), metric_member.get('Dimensions.member', _EMPTY_LIST), timestamp)) MetricDatum(
namespace,
metric_member["MetricName"],
float(metric_member.get("Value", 0)),
metric_member.get("Dimensions.member", _EMPTY_LIST),
timestamp,
)
)
def get_metric_statistics(self, namespace, metric_name, start_time, end_time, period, stats): def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats
):
period_delta = timedelta(seconds=period) period_delta = timedelta(seconds=period)
filtered_data = [md for md in self.metric_data if filtered_data = [
md.namespace == namespace and md.name == metric_name and start_time <= md.timestamp <= end_time] md
for md in self.metric_data
if md.namespace == namespace
and md.name == metric_name
and start_time <= md.timestamp <= end_time
]
# earliest to oldest # earliest to oldest
filtered_data = sorted(filtered_data, key=lambda x: x.timestamp) filtered_data = sorted(filtered_data, key=lambda x: x.timestamp)
@ -249,9 +309,15 @@ class CloudWatchBackend(BaseBackend):
idx = 0 idx = 0
data = list() data = list()
for dt in daterange(filtered_data[0].timestamp, filtered_data[-1].timestamp + period_delta, period_delta): for dt in daterange(
filtered_data[0].timestamp,
filtered_data[-1].timestamp + period_delta,
period_delta,
):
s = Statistics(stats, dt) s = Statistics(stats, dt)
while idx < len(filtered_data) and filtered_data[idx].timestamp < (dt + period_delta): while idx < len(filtered_data) and filtered_data[idx].timestamp < (
dt + period_delta
):
s.values.append(filtered_data[idx].value) s.values.append(filtered_data[idx].value)
idx += 1 idx += 1
@ -268,7 +334,7 @@ class CloudWatchBackend(BaseBackend):
def put_dashboard(self, name, body): def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body) self.dashboards[name] = Dashboard(name, body)
def list_dashboards(self, prefix=''): def list_dashboards(self, prefix=""):
for key, value in self.dashboards.items(): for key, value in self.dashboards.items():
if key.startswith(prefix): if key.startswith(prefix):
yield value yield value
@ -280,7 +346,12 @@ class CloudWatchBackend(BaseBackend):
left_over = to_delete - all_dashboards left_over = to_delete - all_dashboards
if len(left_over) > 0: if len(left_over) > 0:
# Some dashboards are not found # Some dashboards are not found
return False, 'The specified dashboard does not exist. [{0}]'.format(', '.join(left_over)) return (
False,
"The specified dashboard does not exist. [{0}]".format(
", ".join(left_over)
),
)
for dashboard in to_delete: for dashboard in to_delete:
del self.dashboards[dashboard] del self.dashboards[dashboard]
@ -295,32 +366,66 @@ class CloudWatchBackend(BaseBackend):
if reason_data is not None: if reason_data is not None:
json.loads(reason_data) json.loads(reason_data)
except ValueError: except ValueError:
raise RESTError('InvalidFormat', 'StateReasonData is invalid JSON') raise RESTError("InvalidFormat", "StateReasonData is invalid JSON")
if alarm_name not in self.alarms: if alarm_name not in self.alarms:
raise RESTError('ResourceNotFound', 'Alarm {0} not found'.format(alarm_name), status=404) raise RESTError(
"ResourceNotFound", "Alarm {0} not found".format(alarm_name), status=404
)
if state_value not in ('OK', 'ALARM', 'INSUFFICIENT_DATA'): if state_value not in ("OK", "ALARM", "INSUFFICIENT_DATA"):
raise RESTError('InvalidParameterValue', 'StateValue is not one of OK | ALARM | INSUFFICIENT_DATA') raise RESTError(
"InvalidParameterValue",
"StateValue is not one of OK | ALARM | INSUFFICIENT_DATA",
)
self.alarms[alarm_name].update_state(reason, reason_data, state_value) self.alarms[alarm_name].update_state(reason, reason_data, state_value)
def list_metrics(self, next_token, namespace, metric_name):
if next_token:
if next_token not in self.paged_metric_data:
raise RESTError(
"PaginationException", "Request parameter NextToken is invalid"
)
else:
metrics = self.paged_metric_data[next_token]
del self.paged_metric_data[next_token] # Cant reuse same token twice
return self._get_paginated(metrics)
else:
metrics = self.get_filtered_metrics(metric_name, namespace)
return self._get_paginated(metrics)
def get_filtered_metrics(self, metric_name, namespace):
metrics = self.get_all_metrics()
if namespace:
metrics = [md for md in metrics if md.namespace == namespace]
if metric_name:
metrics = [md for md in metrics if md.name == metric_name]
return metrics
def _get_paginated(self, metrics):
if len(metrics) > 500:
next_token = str(uuid4())
self.paged_metric_data[next_token] = metrics[500:]
return next_token, metrics[0:500]
else:
return None, metrics
class LogGroup(BaseModel): class LogGroup(BaseModel):
def __init__(self, spec): def __init__(self, spec):
# required # required
self.name = spec['LogGroupName'] self.name = spec["LogGroupName"]
# optional # optional
self.tags = spec.get('Tags', []) self.tags = spec.get("Tags", [])
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(
properties = cloudformation_json['Properties'] cls, resource_name, cloudformation_json, region_name
spec = { ):
'LogGroupName': properties['LogGroupName'] properties = cloudformation_json["Properties"]
} spec = {"LogGroupName": properties["LogGroupName"]}
optional_properties = 'Tags'.split() optional_properties = "Tags".split()
for prop in optional_properties: for prop in optional_properties:
if prop in properties: if prop in properties:
spec[prop] = properties[prop] spec[prop] = properties[prop]
@ -328,5 +433,11 @@ class LogGroup(BaseModel):
cloudwatch_backends = {} cloudwatch_backends = {}
for region in boto.ec2.cloudwatch.regions(): for region in Session().get_available_regions("cloudwatch"):
cloudwatch_backends[region.name] = CloudWatchBackend() cloudwatch_backends[region] = CloudWatchBackend()
for region in Session().get_available_regions(
"cloudwatch", partition_name="aws-us-gov"
):
cloudwatch_backends[region] = CloudWatchBackend()
for region in Session().get_available_regions("cloudwatch", partition_name="aws-cn"):
cloudwatch_backends[region] = CloudWatchBackend()

View File

@ -6,7 +6,6 @@ from dateutil.parser import parse as dtparse
class CloudWatchResponse(BaseResponse): class CloudWatchResponse(BaseResponse):
@property @property
def cloudwatch_backend(self): def cloudwatch_backend(self):
return cloudwatch_backends[self.region] return cloudwatch_backends[self.region]
@ -17,45 +16,54 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def put_metric_alarm(self): def put_metric_alarm(self):
name = self._get_param('AlarmName') name = self._get_param("AlarmName")
namespace = self._get_param('Namespace') namespace = self._get_param("Namespace")
metric_name = self._get_param('MetricName') metric_name = self._get_param("MetricName")
comparison_operator = self._get_param('ComparisonOperator') comparison_operator = self._get_param("ComparisonOperator")
evaluation_periods = self._get_param('EvaluationPeriods') evaluation_periods = self._get_param("EvaluationPeriods")
period = self._get_param('Period') period = self._get_param("Period")
threshold = self._get_param('Threshold') threshold = self._get_param("Threshold")
statistic = self._get_param('Statistic') statistic = self._get_param("Statistic")
description = self._get_param('AlarmDescription') description = self._get_param("AlarmDescription")
dimensions = self._get_list_prefix('Dimensions.member') dimensions = self._get_list_prefix("Dimensions.member")
alarm_actions = self._get_multi_param('AlarmActions.member') alarm_actions = self._get_multi_param("AlarmActions.member")
ok_actions = self._get_multi_param('OKActions.member') ok_actions = self._get_multi_param("OKActions.member")
insufficient_data_actions = self._get_multi_param( insufficient_data_actions = self._get_multi_param(
"InsufficientDataActions.member") "InsufficientDataActions.member"
unit = self._get_param('Unit') )
alarm = self.cloudwatch_backend.put_metric_alarm(name, namespace, metric_name, unit = self._get_param("Unit")
comparison_operator, alarm = self.cloudwatch_backend.put_metric_alarm(
evaluation_periods, period, name,
threshold, statistic, namespace,
description, dimensions, metric_name,
alarm_actions, ok_actions, comparison_operator,
insufficient_data_actions, evaluation_periods,
unit) period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
)
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE) template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm) return template.render(alarm=alarm)
@amzn_request_id @amzn_request_id
def describe_alarms(self): def describe_alarms(self):
action_prefix = self._get_param('ActionPrefix') action_prefix = self._get_param("ActionPrefix")
alarm_name_prefix = self._get_param('AlarmNamePrefix') alarm_name_prefix = self._get_param("AlarmNamePrefix")
alarm_names = self._get_multi_param('AlarmNames.member') alarm_names = self._get_multi_param("AlarmNames.member")
state_value = self._get_param('StateValue') state_value = self._get_param("StateValue")
if action_prefix: if action_prefix:
alarms = self.cloudwatch_backend.get_alarms_by_action_prefix( alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(action_prefix)
action_prefix)
elif alarm_name_prefix: elif alarm_name_prefix:
alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix( alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix(
alarm_name_prefix) alarm_name_prefix
)
elif alarm_names: elif alarm_names:
alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names) alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names)
elif state_value: elif state_value:
@ -68,15 +76,15 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def delete_alarms(self): def delete_alarms(self):
alarm_names = self._get_multi_param('AlarmNames.member') alarm_names = self._get_multi_param("AlarmNames.member")
self.cloudwatch_backend.delete_alarms(alarm_names) self.cloudwatch_backend.delete_alarms(alarm_names)
template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE) template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE)
return template.render() return template.render()
@amzn_request_id @amzn_request_id
def put_metric_data(self): def put_metric_data(self):
namespace = self._get_param('Namespace') namespace = self._get_param("Namespace")
metric_data = self._get_multi_param('MetricData.member') metric_data = self._get_multi_param("MetricData.member")
self.cloudwatch_backend.put_metric_data(namespace, metric_data) self.cloudwatch_backend.put_metric_data(namespace, metric_data)
template = self.response_template(PUT_METRIC_DATA_TEMPLATE) template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
@ -84,43 +92,52 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def get_metric_statistics(self): def get_metric_statistics(self):
namespace = self._get_param('Namespace') namespace = self._get_param("Namespace")
metric_name = self._get_param('MetricName') metric_name = self._get_param("MetricName")
start_time = dtparse(self._get_param('StartTime')) start_time = dtparse(self._get_param("StartTime"))
end_time = dtparse(self._get_param('EndTime')) end_time = dtparse(self._get_param("EndTime"))
period = int(self._get_param('Period')) period = int(self._get_param("Period"))
statistics = self._get_multi_param("Statistics.member") statistics = self._get_multi_param("Statistics.member")
# Unsupported Parameters (To Be Implemented) # Unsupported Parameters (To Be Implemented)
unit = self._get_param('Unit') unit = self._get_param("Unit")
extended_statistics = self._get_param('ExtendedStatistics') extended_statistics = self._get_param("ExtendedStatistics")
dimensions = self._get_param('Dimensions') dimensions = self._get_param("Dimensions")
if unit or extended_statistics or dimensions: if unit or extended_statistics or dimensions:
raise NotImplemented() raise NotImplementedError()
# TODO: this should instead throw InvalidParameterCombination # TODO: this should instead throw InvalidParameterCombination
if not statistics: if not statistics:
raise NotImplemented("Must specify either Statistics or ExtendedStatistics") raise NotImplementedError(
"Must specify either Statistics or ExtendedStatistics"
)
datapoints = self.cloudwatch_backend.get_metric_statistics(namespace, metric_name, start_time, end_time, period, statistics) datapoints = self.cloudwatch_backend.get_metric_statistics(
namespace, metric_name, start_time, end_time, period, statistics
)
template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE) template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE)
return template.render(label=metric_name, datapoints=datapoints) return template.render(label=metric_name, datapoints=datapoints)
@amzn_request_id @amzn_request_id
def list_metrics(self): def list_metrics(self):
metrics = self.cloudwatch_backend.get_all_metrics() namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
next_token = self._get_param("NextToken")
next_token, metrics = self.cloudwatch_backend.list_metrics(
next_token, namespace, metric_name
)
template = self.response_template(LIST_METRICS_TEMPLATE) template = self.response_template(LIST_METRICS_TEMPLATE)
return template.render(metrics=metrics) return template.render(metrics=metrics, next_token=next_token)
@amzn_request_id @amzn_request_id
def delete_dashboards(self): def delete_dashboards(self):
dashboards = self._get_multi_param('DashboardNames.member') dashboards = self._get_multi_param("DashboardNames.member")
if dashboards is None: if dashboards is None:
return self._error('InvalidParameterValue', 'Need at least 1 dashboard') return self._error("InvalidParameterValue", "Need at least 1 dashboard")
status, error = self.cloudwatch_backend.delete_dashboards(dashboards) status, error = self.cloudwatch_backend.delete_dashboards(dashboards)
if not status: if not status:
return self._error('ResourceNotFound', error) return self._error("ResourceNotFound", error)
template = self.response_template(DELETE_DASHBOARD_TEMPLATE) template = self.response_template(DELETE_DASHBOARD_TEMPLATE)
return template.render() return template.render()
@ -143,18 +160,18 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def get_dashboard(self): def get_dashboard(self):
dashboard_name = self._get_param('DashboardName') dashboard_name = self._get_param("DashboardName")
dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name) dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name)
if dashboard is None: if dashboard is None:
return self._error('ResourceNotFound', 'Dashboard does not exist') return self._error("ResourceNotFound", "Dashboard does not exist")
template = self.response_template(GET_DASHBOARD_TEMPLATE) template = self.response_template(GET_DASHBOARD_TEMPLATE)
return template.render(dashboard=dashboard) return template.render(dashboard=dashboard)
@amzn_request_id @amzn_request_id
def list_dashboards(self): def list_dashboards(self):
prefix = self._get_param('DashboardNamePrefix', '') prefix = self._get_param("DashboardNamePrefix", "")
dashboards = self.cloudwatch_backend.list_dashboards(prefix) dashboards = self.cloudwatch_backend.list_dashboards(prefix)
@ -163,13 +180,13 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def put_dashboard(self): def put_dashboard(self):
name = self._get_param('DashboardName') name = self._get_param("DashboardName")
body = self._get_param('DashboardBody') body = self._get_param("DashboardBody")
try: try:
json.loads(body) json.loads(body)
except ValueError: except ValueError:
return self._error('InvalidParameterInput', 'Body is invalid JSON') return self._error("InvalidParameterInput", "Body is invalid JSON")
self.cloudwatch_backend.put_dashboard(name, body) self.cloudwatch_backend.put_dashboard(name, body)
@ -178,12 +195,14 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def set_alarm_state(self): def set_alarm_state(self):
alarm_name = self._get_param('AlarmName') alarm_name = self._get_param("AlarmName")
reason = self._get_param('StateReason') reason = self._get_param("StateReason")
reason_data = self._get_param('StateReasonData') reason_data = self._get_param("StateReasonData")
state_value = self._get_param('StateValue') state_value = self._get_param("StateValue")
self.cloudwatch_backend.set_alarm_state(alarm_name, reason, reason_data, state_value) self.cloudwatch_backend.set_alarm_state(
alarm_name, reason, reason_data, state_value
)
template = self.response_template(SET_ALARM_STATE_TEMPLATE) template = self.response_template(SET_ALARM_STATE_TEMPLATE)
return template.render() return template.render()
@ -275,7 +294,7 @@ GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://m
<Label>{{ label }}</Label> <Label>{{ label }}</Label>
<Datapoints> <Datapoints>
{% for datapoint in datapoints %} {% for datapoint in datapoints %}
<Datapoint> <member>
{% if datapoint.sum is not none %} {% if datapoint.sum is not none %}
<Sum>{{ datapoint.sum }}</Sum> <Sum>{{ datapoint.sum }}</Sum>
{% endif %} {% endif %}
@ -302,7 +321,7 @@ GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://m
<Timestamp>{{ datapoint.timestamp }}</Timestamp> <Timestamp>{{ datapoint.timestamp }}</Timestamp>
<Unit>{{ datapoint.unit }}</Unit> <Unit>{{ datapoint.unit }}</Unit>
</Datapoint> </member>
{% endfor %} {% endfor %}
</Datapoints> </Datapoints>
</GetMetricStatisticsResult> </GetMetricStatisticsResult>
@ -326,9 +345,11 @@ LIST_METRICS_TEMPLATE = """<ListMetricsResponse xmlns="http://monitoring.amazona
</member> </member>
{% endfor %} {% endfor %}
</Metrics> </Metrics>
{% if next_token is not none %}
<NextToken> <NextToken>
96e88479-4662-450b-8a13-239ded6ce9fe {{ next_token }}
</NextToken> </NextToken>
{% endif %}
</ListMetricsResult> </ListMetricsResult>
</ListMetricsResponse>""" </ListMetricsResponse>"""

View File

@ -1,9 +1,5 @@
from .responses import CloudWatchResponse from .responses import CloudWatchResponse
url_bases = [ url_bases = ["https?://monitoring.(.+).amazonaws.com"]
"https?://monitoring.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": CloudWatchResponse.dispatch}
'{0}/$': CloudWatchResponse.dispatch,
}

View File

@ -0,0 +1,4 @@
from .models import codecommit_backends
from ..core.models import base_decorator
mock_codecommit = base_decorator(codecommit_backends)

View File

@ -0,0 +1,35 @@
from moto.core.exceptions import JsonRESTError
class RepositoryNameExistsException(JsonRESTError):
code = 400
def __init__(self, repository_name):
super(RepositoryNameExistsException, self).__init__(
"RepositoryNameExistsException",
"Repository named {0} already exists".format(repository_name),
)
class RepositoryDoesNotExistException(JsonRESTError):
code = 400
def __init__(self, repository_name):
super(RepositoryDoesNotExistException, self).__init__(
"RepositoryDoesNotExistException",
"{0} does not exist".format(repository_name),
)
class InvalidRepositoryNameException(JsonRESTError):
code = 400
def __init__(self):
super(InvalidRepositoryNameException, self).__init__(
"InvalidRepositoryNameException",
"The repository name is not valid. Repository names can be any valid "
"combination of letters, numbers, "
"periods, underscores, and dashes between 1 and 100 characters in "
"length. Names are case sensitive. "
"For more information, see Limits in the AWS CodeCommit User Guide. ",
)

69
moto/codecommit/models.py Normal file
View File

@ -0,0 +1,69 @@
from boto3 import Session
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from datetime import datetime
from moto.iam.models import ACCOUNT_ID
from .exceptions import RepositoryDoesNotExistException, RepositoryNameExistsException
import uuid
class CodeCommit(BaseModel):
def __init__(self, region, repository_description, repository_name):
current_date = iso_8601_datetime_with_milliseconds(datetime.utcnow())
self.repository_metadata = dict()
self.repository_metadata["repositoryName"] = repository_name
self.repository_metadata[
"cloneUrlSsh"
] = "ssh://git-codecommit.{0}.amazonaws.com/v1/repos/{1}".format(
region, repository_name
)
self.repository_metadata[
"cloneUrlHttp"
] = "https://git-codecommit.{0}.amazonaws.com/v1/repos/{1}".format(
region, repository_name
)
self.repository_metadata["creationDate"] = current_date
self.repository_metadata["lastModifiedDate"] = current_date
self.repository_metadata["repositoryDescription"] = repository_description
self.repository_metadata["repositoryId"] = str(uuid.uuid4())
self.repository_metadata["Arn"] = "arn:aws:codecommit:{0}:{1}:{2}".format(
region, ACCOUNT_ID, repository_name
)
self.repository_metadata["accountId"] = ACCOUNT_ID
class CodeCommitBackend(BaseBackend):
def __init__(self):
self.repositories = {}
def create_repository(self, region, repository_name, repository_description):
repository = self.repositories.get(repository_name)
if repository:
raise RepositoryNameExistsException(repository_name)
self.repositories[repository_name] = CodeCommit(
region, repository_description, repository_name
)
return self.repositories[repository_name].repository_metadata
def get_repository(self, repository_name):
repository = self.repositories.get(repository_name)
if not repository:
raise RepositoryDoesNotExistException(repository_name)
return repository.repository_metadata
def delete_repository(self, repository_name):
repository = self.repositories.get(repository_name)
if repository:
self.repositories.pop(repository_name)
return repository.repository_metadata.get("repositoryId")
return None
codecommit_backends = {}
for region in Session().get_available_regions("codecommit"):
codecommit_backends[region] = CodeCommitBackend()

View File

@ -0,0 +1,57 @@
import json
import re
from moto.core.responses import BaseResponse
from .models import codecommit_backends
from .exceptions import InvalidRepositoryNameException
def _is_repository_name_valid(repository_name):
name_regex = re.compile(r"[\w\.-]+")
result = name_regex.split(repository_name)
if len(result) > 0:
for match in result:
if len(match) > 0:
return False
return True
class CodeCommitResponse(BaseResponse):
@property
def codecommit_backend(self):
return codecommit_backends[self.region]
def create_repository(self):
if not _is_repository_name_valid(self._get_param("repositoryName")):
raise InvalidRepositoryNameException()
repository_metadata = self.codecommit_backend.create_repository(
self.region,
self._get_param("repositoryName"),
self._get_param("repositoryDescription"),
)
return json.dumps({"repositoryMetadata": repository_metadata})
def get_repository(self):
if not _is_repository_name_valid(self._get_param("repositoryName")):
raise InvalidRepositoryNameException()
repository_metadata = self.codecommit_backend.get_repository(
self._get_param("repositoryName")
)
return json.dumps({"repositoryMetadata": repository_metadata})
def delete_repository(self):
if not _is_repository_name_valid(self._get_param("repositoryName")):
raise InvalidRepositoryNameException()
repository_id = self.codecommit_backend.delete_repository(
self._get_param("repositoryName")
)
if repository_id:
return json.dumps({"repositoryId": repository_id})
return json.dumps({})

6
moto/codecommit/urls.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .responses import CodeCommitResponse
url_bases = ["https?://codecommit.(.+).amazonaws.com"]
url_paths = {"{0}/$": CodeCommitResponse.dispatch}

View File

@ -0,0 +1,4 @@
from .models import codepipeline_backends
from ..core.models import base_decorator
mock_codepipeline = base_decorator(codepipeline_backends)

View File

@ -0,0 +1,44 @@
from moto.core.exceptions import JsonRESTError
class InvalidStructureException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidStructureException, self).__init__(
"InvalidStructureException", message
)
class PipelineNotFoundException(JsonRESTError):
code = 400
def __init__(self, message):
super(PipelineNotFoundException, self).__init__(
"PipelineNotFoundException", message
)
class ResourceNotFoundException(JsonRESTError):
code = 400
def __init__(self, message):
super(ResourceNotFoundException, self).__init__(
"ResourceNotFoundException", message
)
class InvalidTagsException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidTagsException, self).__init__("InvalidTagsException", message)
class TooManyTagsException(JsonRESTError):
code = 400
def __init__(self, arn):
super(TooManyTagsException, self).__init__(
"TooManyTagsException", "Tag limit exceeded for resource [{}].".format(arn)
)

218
moto/codepipeline/models.py Normal file
View File

@ -0,0 +1,218 @@
import json
from datetime import datetime
from boto3 import Session
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.iam.exceptions import IAMNotFoundException
from moto.iam import iam_backends
from moto.codepipeline.exceptions import (
InvalidStructureException,
PipelineNotFoundException,
ResourceNotFoundException,
InvalidTagsException,
TooManyTagsException,
)
from moto.core import BaseBackend, BaseModel
from moto.iam.models import ACCOUNT_ID
class CodePipeline(BaseModel):
def __init__(self, region, pipeline):
# the version number for a new pipeline is always 1
pipeline["version"] = 1
self.pipeline = self.add_default_values(pipeline)
self.tags = {}
self._arn = "arn:aws:codepipeline:{0}:{1}:{2}".format(
region, ACCOUNT_ID, pipeline["name"]
)
self._created = datetime.utcnow()
self._updated = datetime.utcnow()
@property
def metadata(self):
return {
"pipelineArn": self._arn,
"created": iso_8601_datetime_with_milliseconds(self._created),
"updated": iso_8601_datetime_with_milliseconds(self._updated),
}
def add_default_values(self, pipeline):
for stage in pipeline["stages"]:
for action in stage["actions"]:
if "runOrder" not in action:
action["runOrder"] = 1
if "configuration" not in action:
action["configuration"] = {}
if "outputArtifacts" not in action:
action["outputArtifacts"] = []
if "inputArtifacts" not in action:
action["inputArtifacts"] = []
return pipeline
def validate_tags(self, tags):
for tag in tags:
if tag["key"].startswith("aws:"):
raise InvalidTagsException(
"Not allowed to modify system tags. "
"System tags start with 'aws:'. "
"msg=[Caller is an end user and not allowed to mutate system tags]"
)
if (len(self.tags) + len(tags)) > 50:
raise TooManyTagsException(self._arn)
class CodePipelineBackend(BaseBackend):
def __init__(self):
self.pipelines = {}
@property
def iam_backend(self):
return iam_backends["global"]
def create_pipeline(self, region, pipeline, tags):
if pipeline["name"] in self.pipelines:
raise InvalidStructureException(
"A pipeline with the name '{0}' already exists in account '{1}'".format(
pipeline["name"], ACCOUNT_ID
)
)
try:
role = self.iam_backend.get_role_by_arn(pipeline["roleArn"])
service_principal = json.loads(role.assume_role_policy_document)[
"Statement"
][0]["Principal"]["Service"]
if "codepipeline.amazonaws.com" not in service_principal:
raise IAMNotFoundException("")
except IAMNotFoundException:
raise InvalidStructureException(
"CodePipeline is not authorized to perform AssumeRole on role {}".format(
pipeline["roleArn"]
)
)
if len(pipeline["stages"]) < 2:
raise InvalidStructureException(
"Pipeline has only 1 stage(s). There should be a minimum of 2 stages in a pipeline"
)
self.pipelines[pipeline["name"]] = CodePipeline(region, pipeline)
if tags:
self.pipelines[pipeline["name"]].validate_tags(tags)
new_tags = {tag["key"]: tag["value"] for tag in tags}
self.pipelines[pipeline["name"]].tags.update(new_tags)
return pipeline, sorted(tags, key=lambda i: i["key"])
def get_pipeline(self, name):
codepipeline = self.pipelines.get(name)
if not codepipeline:
raise PipelineNotFoundException(
"Account '{0}' does not have a pipeline with name '{1}'".format(
ACCOUNT_ID, name
)
)
return codepipeline.pipeline, codepipeline.metadata
def update_pipeline(self, pipeline):
codepipeline = self.pipelines.get(pipeline["name"])
if not codepipeline:
raise ResourceNotFoundException(
"The account with id '{0}' does not include a pipeline with the name '{1}'".format(
ACCOUNT_ID, pipeline["name"]
)
)
# version number is auto incremented
pipeline["version"] = codepipeline.pipeline["version"] + 1
codepipeline._updated = datetime.utcnow()
codepipeline.pipeline = codepipeline.add_default_values(pipeline)
return codepipeline.pipeline
def list_pipelines(self):
pipelines = []
for name, codepipeline in self.pipelines.items():
pipelines.append(
{
"name": name,
"version": codepipeline.pipeline["version"],
"created": codepipeline.metadata["created"],
"updated": codepipeline.metadata["updated"],
}
)
return sorted(pipelines, key=lambda i: i["name"])
def delete_pipeline(self, name):
self.pipelines.pop(name, None)
def list_tags_for_resource(self, arn):
name = arn.split(":")[-1]
pipeline = self.pipelines.get(name)
if not pipeline:
raise ResourceNotFoundException(
"The account with id '{0}' does not include a pipeline with the name '{1}'".format(
ACCOUNT_ID, name
)
)
tags = [{"key": key, "value": value} for key, value in pipeline.tags.items()]
return sorted(tags, key=lambda i: i["key"])
def tag_resource(self, arn, tags):
name = arn.split(":")[-1]
pipeline = self.pipelines.get(name)
if not pipeline:
raise ResourceNotFoundException(
"The account with id '{0}' does not include a pipeline with the name '{1}'".format(
ACCOUNT_ID, name
)
)
pipeline.validate_tags(tags)
for tag in tags:
pipeline.tags.update({tag["key"]: tag["value"]})
def untag_resource(self, arn, tag_keys):
name = arn.split(":")[-1]
pipeline = self.pipelines.get(name)
if not pipeline:
raise ResourceNotFoundException(
"The account with id '{0}' does not include a pipeline with the name '{1}'".format(
ACCOUNT_ID, name
)
)
for key in tag_keys:
pipeline.tags.pop(key, None)
codepipeline_backends = {}
for region in Session().get_available_regions("codepipeline"):
codepipeline_backends[region] = CodePipelineBackend()
for region in Session().get_available_regions(
"codepipeline", partition_name="aws-us-gov"
):
codepipeline_backends[region] = CodePipelineBackend()
for region in Session().get_available_regions("codepipeline", partition_name="aws-cn"):
codepipeline_backends[region] = CodePipelineBackend()

View File

@ -0,0 +1,62 @@
import json
from moto.core.responses import BaseResponse
from .models import codepipeline_backends
class CodePipelineResponse(BaseResponse):
@property
def codepipeline_backend(self):
return codepipeline_backends[self.region]
def create_pipeline(self):
pipeline, tags = self.codepipeline_backend.create_pipeline(
self.region, self._get_param("pipeline"), self._get_param("tags")
)
return json.dumps({"pipeline": pipeline, "tags": tags})
def get_pipeline(self):
pipeline, metadata = self.codepipeline_backend.get_pipeline(
self._get_param("name")
)
return json.dumps({"pipeline": pipeline, "metadata": metadata})
def update_pipeline(self):
pipeline = self.codepipeline_backend.update_pipeline(
self._get_param("pipeline")
)
return json.dumps({"pipeline": pipeline})
def list_pipelines(self):
pipelines = self.codepipeline_backend.list_pipelines()
return json.dumps({"pipelines": pipelines})
def delete_pipeline(self):
self.codepipeline_backend.delete_pipeline(self._get_param("name"))
return ""
def list_tags_for_resource(self):
tags = self.codepipeline_backend.list_tags_for_resource(
self._get_param("resourceArn")
)
return json.dumps({"tags": tags})
def tag_resource(self):
self.codepipeline_backend.tag_resource(
self._get_param("resourceArn"), self._get_param("tags")
)
return ""
def untag_resource(self):
self.codepipeline_backend.untag_resource(
self._get_param("resourceArn"), self._get_param("tagKeys")
)
return ""

View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .responses import CodePipelineResponse
url_bases = ["https?://codepipeline.(.+).amazonaws.com"]
url_paths = {"{0}/$": CodePipelineResponse.dispatch}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import cognitoidentity_backends from .models import cognitoidentity_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
cognitoidentity_backend = cognitoidentity_backends['us-east-1'] cognitoidentity_backend = cognitoidentity_backends["us-east-1"]
mock_cognitoidentity = base_decorator(cognitoidentity_backends) mock_cognitoidentity = base_decorator(cognitoidentity_backends)
mock_cognitoidentity_deprecated = deprecated_base_decorator(cognitoidentity_backends) mock_cognitoidentity_deprecated = deprecated_base_decorator(cognitoidentity_backends)

View File

@ -0,0 +1,13 @@
from __future__ import unicode_literals
import json
from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest):
def __init__(self, message):
super(ResourceNotFoundError, self).__init__()
self.description = json.dumps(
{"message": message, "__type": "ResourceNotFoundException"}
)

View File

@ -3,32 +3,34 @@ from __future__ import unicode_literals
import datetime import datetime
import json import json
import boto.cognito.identity from boto3 import Session
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
from .exceptions import ResourceNotFoundError
from .utils import get_random_identity_id from .utils import get_random_identity_id
class CognitoIdentity(BaseModel): class CognitoIdentity(BaseModel):
def __init__(self, region, identity_pool_name, **kwargs): def __init__(self, region, identity_pool_name, **kwargs):
self.identity_pool_name = identity_pool_name self.identity_pool_name = identity_pool_name
self.allow_unauthenticated_identities = kwargs.get('allow_unauthenticated_identities', '') self.allow_unauthenticated_identities = kwargs.get(
self.supported_login_providers = kwargs.get('supported_login_providers', {}) "allow_unauthenticated_identities", ""
self.developer_provider_name = kwargs.get('developer_provider_name', '') )
self.open_id_connect_provider_arns = kwargs.get('open_id_connect_provider_arns', []) self.supported_login_providers = kwargs.get("supported_login_providers", {})
self.cognito_identity_providers = kwargs.get('cognito_identity_providers', []) self.developer_provider_name = kwargs.get("developer_provider_name", "")
self.saml_provider_arns = kwargs.get('saml_provider_arns', []) self.open_id_connect_provider_arns = kwargs.get(
"open_id_connect_provider_arns", []
)
self.cognito_identity_providers = kwargs.get("cognito_identity_providers", [])
self.saml_provider_arns = kwargs.get("saml_provider_arns", [])
self.identity_pool_id = get_random_identity_id(region) self.identity_pool_id = get_random_identity_id(region)
self.creation_time = datetime.datetime.utcnow() self.creation_time = datetime.datetime.utcnow()
class CognitoIdentityBackend(BaseBackend): class CognitoIdentityBackend(BaseBackend):
def __init__(self, region): def __init__(self, region):
super(CognitoIdentityBackend, self).__init__() super(CognitoIdentityBackend, self).__init__()
self.region = region self.region = region
@ -39,34 +41,67 @@ class CognitoIdentityBackend(BaseBackend):
self.__dict__ = {} self.__dict__ = {}
self.__init__(region) self.__init__(region)
def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities, def describe_identity_pool(self, identity_pool_id):
supported_login_providers, developer_provider_name, open_id_connect_provider_arns, identity_pool = self.identity_pools.get(identity_pool_id, None)
cognito_identity_providers, saml_provider_arns):
new_identity = CognitoIdentity(self.region, identity_pool_name, if not identity_pool:
raise ResourceNotFoundError(identity_pool)
response = json.dumps(
{
"AllowUnauthenticatedIdentities": identity_pool.allow_unauthenticated_identities,
"CognitoIdentityProviders": identity_pool.cognito_identity_providers,
"DeveloperProviderName": identity_pool.developer_provider_name,
"IdentityPoolId": identity_pool.identity_pool_id,
"IdentityPoolName": identity_pool.identity_pool_name,
"IdentityPoolTags": {},
"OpenIdConnectProviderARNs": identity_pool.open_id_connect_provider_arns,
"SamlProviderARNs": identity_pool.saml_provider_arns,
"SupportedLoginProviders": identity_pool.supported_login_providers,
}
)
return response
def create_identity_pool(
self,
identity_pool_name,
allow_unauthenticated_identities,
supported_login_providers,
developer_provider_name,
open_id_connect_provider_arns,
cognito_identity_providers,
saml_provider_arns,
):
new_identity = CognitoIdentity(
self.region,
identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities, allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers, supported_login_providers=supported_login_providers,
developer_provider_name=developer_provider_name, developer_provider_name=developer_provider_name,
open_id_connect_provider_arns=open_id_connect_provider_arns, open_id_connect_provider_arns=open_id_connect_provider_arns,
cognito_identity_providers=cognito_identity_providers, cognito_identity_providers=cognito_identity_providers,
saml_provider_arns=saml_provider_arns) saml_provider_arns=saml_provider_arns,
)
self.identity_pools[new_identity.identity_pool_id] = new_identity self.identity_pools[new_identity.identity_pool_id] = new_identity
response = json.dumps({ response = json.dumps(
'IdentityPoolId': new_identity.identity_pool_id, {
'IdentityPoolName': new_identity.identity_pool_name, "IdentityPoolId": new_identity.identity_pool_id,
'AllowUnauthenticatedIdentities': new_identity.allow_unauthenticated_identities, "IdentityPoolName": new_identity.identity_pool_name,
'SupportedLoginProviders': new_identity.supported_login_providers, "AllowUnauthenticatedIdentities": new_identity.allow_unauthenticated_identities,
'DeveloperProviderName': new_identity.developer_provider_name, "SupportedLoginProviders": new_identity.supported_login_providers,
'OpenIdConnectProviderARNs': new_identity.open_id_connect_provider_arns, "DeveloperProviderName": new_identity.developer_provider_name,
'CognitoIdentityProviders': new_identity.cognito_identity_providers, "OpenIdConnectProviderARNs": new_identity.open_id_connect_provider_arns,
'SamlProviderARNs': new_identity.saml_provider_arns "CognitoIdentityProviders": new_identity.cognito_identity_providers,
}) "SamlProviderARNs": new_identity.saml_provider_arns,
}
)
return response return response
def get_id(self): def get_id(self):
identity_id = {'IdentityId': get_random_identity_id(self.region)} identity_id = {"IdentityId": get_random_identity_id(self.region)}
return json.dumps(identity_id) return json.dumps(identity_id)
def get_credentials_for_identity(self, identity_id): def get_credentials_for_identity(self, identity_id):
@ -76,26 +111,38 @@ class CognitoIdentityBackend(BaseBackend):
expiration_str = str(iso_8601_datetime_with_milliseconds(expiration)) expiration_str = str(iso_8601_datetime_with_milliseconds(expiration))
response = json.dumps( response = json.dumps(
{ {
"Credentials": "Credentials": {
{
"AccessKeyId": "TESTACCESSKEY12345", "AccessKeyId": "TESTACCESSKEY12345",
"Expiration": expiration_str, "Expiration": expiration_str,
"SecretKey": "ABCSECRETKEY", "SecretKey": "ABCSECRETKEY",
"SessionToken": "ABC12345" "SessionToken": "ABC12345",
}, },
"IdentityId": identity_id "IdentityId": identity_id,
}) }
)
return response return response
def get_open_id_token_for_developer_identity(self, identity_id): def get_open_id_token_for_developer_identity(self, identity_id):
response = json.dumps( response = json.dumps(
{ {"IdentityId": identity_id, "Token": get_random_identity_id(self.region)}
"IdentityId": identity_id, )
"Token": get_random_identity_id(self.region) return response
})
def get_open_id_token(self, identity_id):
response = json.dumps(
{"IdentityId": identity_id, "Token": get_random_identity_id(self.region)}
)
return response return response
cognitoidentity_backends = {} cognitoidentity_backends = {}
for region in boto.cognito.identity.regions(): for region in Session().get_available_regions("cognito-identity"):
cognitoidentity_backends[region.name] = CognitoIdentityBackend(region.name) cognitoidentity_backends[region] = CognitoIdentityBackend(region)
for region in Session().get_available_regions(
"cognito-identity", partition_name="aws-us-gov"
):
cognitoidentity_backends[region] = CognitoIdentityBackend(region)
for region in Session().get_available_regions(
"cognito-identity", partition_name="aws-cn"
):
cognitoidentity_backends[region] = CognitoIdentityBackend(region)

View File

@ -1,21 +1,22 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import cognitoidentity_backends from .models import cognitoidentity_backends
from .utils import get_random_identity_id from .utils import get_random_identity_id
class CognitoIdentityResponse(BaseResponse): class CognitoIdentityResponse(BaseResponse):
def create_identity_pool(self): def create_identity_pool(self):
identity_pool_name = self._get_param('IdentityPoolName') identity_pool_name = self._get_param("IdentityPoolName")
allow_unauthenticated_identities = self._get_param('AllowUnauthenticatedIdentities') allow_unauthenticated_identities = self._get_param(
supported_login_providers = self._get_param('SupportedLoginProviders') "AllowUnauthenticatedIdentities"
developer_provider_name = self._get_param('DeveloperProviderName') )
open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs') supported_login_providers = self._get_param("SupportedLoginProviders")
cognito_identity_providers = self._get_param('CognitoIdentityProviders') developer_provider_name = self._get_param("DeveloperProviderName")
saml_provider_arns = self._get_param('SamlProviderARNs') open_id_connect_provider_arns = self._get_param("OpenIdConnectProviderARNs")
cognito_identity_providers = self._get_param("CognitoIdentityProviders")
saml_provider_arns = self._get_param("SamlProviderARNs")
return cognitoidentity_backends[self.region].create_identity_pool( return cognitoidentity_backends[self.region].create_identity_pool(
identity_pool_name=identity_pool_name, identity_pool_name=identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities, allow_unauthenticated_identities=allow_unauthenticated_identities,
@ -23,15 +24,30 @@ class CognitoIdentityResponse(BaseResponse):
developer_provider_name=developer_provider_name, developer_provider_name=developer_provider_name,
open_id_connect_provider_arns=open_id_connect_provider_arns, open_id_connect_provider_arns=open_id_connect_provider_arns,
cognito_identity_providers=cognito_identity_providers, cognito_identity_providers=cognito_identity_providers,
saml_provider_arns=saml_provider_arns) saml_provider_arns=saml_provider_arns,
)
def get_id(self): def get_id(self):
return cognitoidentity_backends[self.region].get_id() return cognitoidentity_backends[self.region].get_id()
def describe_identity_pool(self):
return cognitoidentity_backends[self.region].describe_identity_pool(
self._get_param("IdentityPoolId")
)
def get_credentials_for_identity(self): def get_credentials_for_identity(self):
return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId')) return cognitoidentity_backends[self.region].get_credentials_for_identity(
self._get_param("IdentityId")
)
def get_open_id_token_for_developer_identity(self): def get_open_id_token_for_developer_identity(self):
return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity( return cognitoidentity_backends[
self._get_param('IdentityId') or get_random_identity_id(self.region) self.region
].get_open_id_token_for_developer_identity(
self._get_param("IdentityId") or get_random_identity_id(self.region)
)
def get_open_id_token(self):
return cognitoidentity_backends[self.region].get_open_id_token(
self._get_param("IdentityId") or get_random_identity_id(self.region)
) )

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import CognitoIdentityResponse from .responses import CognitoIdentityResponse
url_bases = [ url_bases = ["https?://cognito-identity.(.+).amazonaws.com"]
"https?://cognito-identity.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": CognitoIdentityResponse.dispatch}
'{0}/$': CognitoIdentityResponse.dispatch,
}

View File

@ -5,40 +5,40 @@ from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest): class ResourceNotFoundError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(ResourceNotFoundError, self).__init__() super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "ResourceNotFoundException"}
'__type': 'ResourceNotFoundException', )
})
class UserNotFoundError(BadRequest): class UserNotFoundError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(UserNotFoundError, self).__init__() super(UserNotFoundError, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "UserNotFoundException"}
'__type': 'UserNotFoundException', )
})
class UsernameExistsException(BadRequest):
def __init__(self, message):
super(UsernameExistsException, self).__init__()
self.description = json.dumps(
{"message": message, "__type": "UsernameExistsException"}
)
class GroupExistsException(BadRequest): class GroupExistsException(BadRequest):
def __init__(self, message): def __init__(self, message):
super(GroupExistsException, self).__init__() super(GroupExistsException, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "GroupExistsException"}
'__type': 'GroupExistsException', )
})
class NotAuthorizedError(BadRequest): class NotAuthorizedError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(NotAuthorizedError, self).__init__() super(NotAuthorizedError, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "NotAuthorizedException"}
'__type': 'NotAuthorizedException', )
})

View File

@ -2,18 +2,25 @@ from __future__ import unicode_literals
import datetime import datetime
import functools import functools
import hashlib
import itertools import itertools
import json import json
import os import os
import time import time
import uuid import uuid
import boto.cognito.identity from boto3 import Session
from jose import jws from jose import jws
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from .exceptions import GroupExistsException, NotAuthorizedError, ResourceNotFoundError, UserNotFoundError from .exceptions import (
GroupExistsException,
NotAuthorizedError,
ResourceNotFoundError,
UserNotFoundError,
UsernameExistsException,
)
UserStatus = { UserStatus = {
"FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD", "FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD",
@ -43,19 +50,22 @@ def paginate(limit, start_arg="next_token", limit_arg="max_results"):
def outer_wrapper(func): def outer_wrapper(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
start = int(default_start if kwargs.get(start_arg) is None else kwargs[start_arg]) start = int(
default_start if kwargs.get(start_arg) is None else kwargs[start_arg]
)
lim = int(limit if kwargs.get(limit_arg) is None else kwargs[limit_arg]) lim = int(limit if kwargs.get(limit_arg) is None else kwargs[limit_arg])
stop = start + lim stop = start + lim
result = func(*args, **kwargs) result = func(*args, **kwargs)
limited_results = list(itertools.islice(result, start, stop)) limited_results = list(itertools.islice(result, start, stop))
next_token = stop if stop < len(result) else None next_token = stop if stop < len(result) else None
return limited_results, next_token return limited_results, next_token
return wrapper return wrapper
return outer_wrapper return outer_wrapper
class CognitoIdpUserPool(BaseModel): class CognitoIdpUserPool(BaseModel):
def __init__(self, region, name, extended_config): def __init__(self, region, name, extended_config):
self.region = region self.region = region
self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex)) self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
@ -73,7 +83,9 @@ class CognitoIdpUserPool(BaseModel):
self.access_tokens = {} self.access_tokens = {}
self.id_tokens = {} self.id_tokens = {}
with open(os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")) as f: with open(
os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")
) as f:
self.json_web_key = json.loads(f.read()) self.json_web_key = json.loads(f.read())
def _base_json(self): def _base_json(self):
@ -90,26 +102,35 @@ class CognitoIdpUserPool(BaseModel):
if extended: if extended:
user_pool_json.update(self.extended_config) user_pool_json.update(self.extended_config)
else: else:
user_pool_json["LambdaConfig"] = self.extended_config.get("LambdaConfig") or {} user_pool_json["LambdaConfig"] = (
self.extended_config.get("LambdaConfig") or {}
)
return user_pool_json return user_pool_json
def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data={}): def create_jwt(
self, client_id, username, token_use, expires_in=60 * 60, extra_data={}
):
now = int(time.time()) now = int(time.time())
payload = { payload = {
"iss": "https://cognito-idp.{}.amazonaws.com/{}".format(self.region, self.id), "iss": "https://cognito-idp.{}.amazonaws.com/{}".format(
self.region, self.id
),
"sub": self.users[username].id, "sub": self.users[username].id,
"aud": client_id, "aud": client_id,
"token_use": "id", "token_use": token_use,
"auth_time": now, "auth_time": now,
"exp": now + expires_in, "exp": now + expires_in,
} }
payload.update(extra_data) payload.update(extra_data)
return jws.sign(payload, self.json_web_key, algorithm='RS256'), expires_in return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in
def create_id_token(self, client_id, username): def create_id_token(self, client_id, username):
id_token, expires_in = self.create_jwt(client_id, username) extra_data = self.get_user_extra_data_by_client_id(client_id, username)
id_token, expires_in = self.create_jwt(
client_id, username, "id", extra_data=extra_data
)
self.id_tokens[id_token] = (client_id, username) self.id_tokens[id_token] = (client_id, username)
return id_token, expires_in return id_token, expires_in
@ -119,11 +140,7 @@ class CognitoIdpUserPool(BaseModel):
return refresh_token return refresh_token
def create_access_token(self, client_id, username): def create_access_token(self, client_id, username):
extra_data = self.get_user_extra_data_by_client_id( access_token, expires_in = self.create_jwt(client_id, username, "access")
client_id, username
)
access_token, expires_in = self.create_jwt(client_id, username,
extra_data=extra_data)
self.access_tokens[access_token] = (client_id, username) self.access_tokens[access_token] = (client_id, username)
return access_token, expires_in return access_token, expires_in
@ -141,37 +158,49 @@ class CognitoIdpUserPool(BaseModel):
current_client = self.clients.get(client_id, None) current_client = self.clients.get(client_id, None)
if current_client: if current_client:
for readable_field in current_client.get_readable_fields(): for readable_field in current_client.get_readable_fields():
attribute = list(filter( attribute = list(
lambda f: f['Name'] == readable_field, filter(
self.users.get(username).attributes lambda f: f["Name"] == readable_field,
)) self.users.get(username).attributes,
)
)
if len(attribute) > 0: if len(attribute) > 0:
extra_data.update({ extra_data.update({attribute[0]["Name"]: attribute[0]["Value"]})
attribute[0]['Name']: attribute[0]['Value']
})
return extra_data return extra_data
class CognitoIdpUserPoolDomain(BaseModel): class CognitoIdpUserPoolDomain(BaseModel):
def __init__(self, user_pool_id, domain, custom_domain_config=None):
def __init__(self, user_pool_id, domain):
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
self.domain = domain self.domain = domain
self.custom_domain_config = custom_domain_config or {}
def to_json(self): def _distribution_name(self):
return { if self.custom_domain_config and "CertificateArn" in self.custom_domain_config:
"UserPoolId": self.user_pool_id, hash = hashlib.md5(
"AWSAccountId": str(uuid.uuid4()), self.custom_domain_config["CertificateArn"].encode("utf-8")
"CloudFrontDistribution": None, ).hexdigest()
"Domain": self.domain, return "{hash}.cloudfront.net".format(hash=hash[:16])
"S3Bucket": None, return None
"Status": "ACTIVE",
"Version": None, def to_json(self, extended=True):
} distribution = self._distribution_name()
if extended:
return {
"UserPoolId": self.user_pool_id,
"AWSAccountId": str(uuid.uuid4()),
"CloudFrontDistribution": distribution,
"Domain": self.domain,
"S3Bucket": None,
"Status": "ACTIVE",
"Version": None,
}
elif distribution:
return {"CloudFrontDomain": distribution}
return None
class CognitoIdpUserPoolClient(BaseModel): class CognitoIdpUserPoolClient(BaseModel):
def __init__(self, user_pool_id, extended_config): def __init__(self, user_pool_id, extended_config):
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
self.id = str(uuid.uuid4()) self.id = str(uuid.uuid4())
@ -193,11 +222,10 @@ class CognitoIdpUserPoolClient(BaseModel):
return user_pool_client_json return user_pool_client_json
def get_readable_fields(self): def get_readable_fields(self):
return self.extended_config.get('ReadAttributes', []) return self.extended_config.get("ReadAttributes", [])
class CognitoIdpIdentityProvider(BaseModel): class CognitoIdpIdentityProvider(BaseModel):
def __init__(self, name, extended_config): def __init__(self, name, extended_config):
self.name = name self.name = name
self.extended_config = extended_config or {} self.extended_config = extended_config or {}
@ -221,7 +249,6 @@ class CognitoIdpIdentityProvider(BaseModel):
class CognitoIdpGroup(BaseModel): class CognitoIdpGroup(BaseModel):
def __init__(self, user_pool_id, group_name, description, role_arn, precedence): def __init__(self, user_pool_id, group_name, description, role_arn, precedence):
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
self.group_name = group_name self.group_name = group_name
@ -248,7 +275,6 @@ class CognitoIdpGroup(BaseModel):
class CognitoIdpUser(BaseModel): class CognitoIdpUser(BaseModel):
def __init__(self, user_pool_id, username, password, status, attributes): def __init__(self, user_pool_id, username, password, status, attributes):
self.id = str(uuid.uuid4()) self.id = str(uuid.uuid4())
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
@ -281,15 +307,25 @@ class CognitoIdpUser(BaseModel):
{ {
"Enabled": self.enabled, "Enabled": self.enabled,
attributes_key: self.attributes, attributes_key: self.attributes,
"MFAOptions": [] "MFAOptions": [],
} }
) )
return user_json return user_json
def update_attributes(self, new_attributes):
def flatten_attrs(attrs):
return {attr["Name"]: attr["Value"] for attr in attrs}
def expand_attrs(attrs):
return [{"Name": k, "Value": v} for k, v in attrs.items()]
flat_attributes = flatten_attrs(self.attributes)
flat_attributes.update(flatten_attrs(new_attributes))
self.attributes = expand_attrs(flat_attributes)
class CognitoIdpBackend(BaseBackend): class CognitoIdpBackend(BaseBackend):
def __init__(self, region): def __init__(self, region):
super(CognitoIdpBackend, self).__init__() super(CognitoIdpBackend, self).__init__()
self.region = region self.region = region
@ -326,11 +362,13 @@ class CognitoIdpBackend(BaseBackend):
del self.user_pools[user_pool_id] del self.user_pools[user_pool_id]
# User pool domain # User pool domain
def create_user_pool_domain(self, user_pool_id, domain): def create_user_pool_domain(self, user_pool_id, domain, custom_domain_config=None):
if user_pool_id not in self.user_pools: if user_pool_id not in self.user_pools:
raise ResourceNotFoundError(user_pool_id) raise ResourceNotFoundError(user_pool_id)
user_pool_domain = CognitoIdpUserPoolDomain(user_pool_id, domain) user_pool_domain = CognitoIdpUserPoolDomain(
user_pool_id, domain, custom_domain_config=custom_domain_config
)
self.user_pool_domains[domain] = user_pool_domain self.user_pool_domains[domain] = user_pool_domain
return user_pool_domain return user_pool_domain
@ -346,6 +384,14 @@ class CognitoIdpBackend(BaseBackend):
del self.user_pool_domains[domain] del self.user_pool_domains[domain]
def update_user_pool_domain(self, domain, custom_domain_config):
if domain not in self.user_pool_domains:
raise ResourceNotFoundError(domain)
user_pool_domain = self.user_pool_domains[domain]
user_pool_domain.custom_domain_config = custom_domain_config
return user_pool_domain
# User pool client # User pool client
def create_user_pool_client(self, user_pool_id, extended_config): def create_user_pool_client(self, user_pool_id, extended_config):
user_pool = self.user_pools.get(user_pool_id) user_pool = self.user_pools.get(user_pool_id)
@ -455,7 +501,9 @@ class CognitoIdpBackend(BaseBackend):
if not user_pool: if not user_pool:
raise ResourceNotFoundError(user_pool_id) raise ResourceNotFoundError(user_pool_id)
group = CognitoIdpGroup(user_pool_id, group_name, description, role_arn, precedence) group = CognitoIdpGroup(
user_pool_id, group_name, description, role_arn, precedence
)
if group.group_name in user_pool.groups: if group.group_name in user_pool.groups:
raise GroupExistsException("A group with the name already exists") raise GroupExistsException("A group with the name already exists")
user_pool.groups[group.group_name] = group user_pool.groups[group.group_name] = group
@ -521,7 +569,16 @@ class CognitoIdpBackend(BaseBackend):
if not user_pool: if not user_pool:
raise ResourceNotFoundError(user_pool_id) raise ResourceNotFoundError(user_pool_id)
user = CognitoIdpUser(user_pool_id, username, temporary_password, UserStatus["FORCE_CHANGE_PASSWORD"], attributes) if username in user_pool.users:
raise UsernameExistsException(username)
user = CognitoIdpUser(
user_pool_id,
username,
temporary_password,
UserStatus["FORCE_CHANGE_PASSWORD"],
attributes,
)
user_pool.users[user.username] = user user_pool.users[user.username] = user
return user return user
@ -567,7 +624,9 @@ class CognitoIdpBackend(BaseBackend):
def _log_user_in(self, user_pool, client, username): def _log_user_in(self, user_pool, client, username):
refresh_token = user_pool.create_refresh_token(client.id, username) refresh_token = user_pool.create_refresh_token(client.id, username)
access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(
refresh_token
)
return { return {
"AuthenticationResult": { "AuthenticationResult": {
@ -610,7 +669,11 @@ class CognitoIdpBackend(BaseBackend):
return self._log_user_in(user_pool, client, username) return self._log_user_in(user_pool, client, username)
elif auth_flow == "REFRESH_TOKEN": elif auth_flow == "REFRESH_TOKEN":
refresh_token = auth_parameters.get("REFRESH_TOKEN") refresh_token = auth_parameters.get("REFRESH_TOKEN")
id_token, access_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) (
id_token,
access_token,
expires_in,
) = user_pool.create_tokens_from_refresh_token(refresh_token)
return { return {
"AuthenticationResult": { "AuthenticationResult": {
@ -622,7 +685,9 @@ class CognitoIdpBackend(BaseBackend):
else: else:
return {} return {}
def respond_to_auth_challenge(self, session, client_id, challenge_name, challenge_responses): def respond_to_auth_challenge(
self, session, client_id, challenge_name, challenge_responses
):
user_pool = self.sessions.get(session) user_pool = self.sessions.get(session)
if not user_pool: if not user_pool:
raise ResourceNotFoundError(session) raise ResourceNotFoundError(session)
@ -673,10 +738,27 @@ class CognitoIdpBackend(BaseBackend):
else: else:
raise NotAuthorizedError(access_token) raise NotAuthorizedError(access_token)
def admin_update_user_attributes(self, user_pool_id, username, attributes):
user_pool = self.user_pools.get(user_pool_id)
if not user_pool:
raise ResourceNotFoundError(user_pool_id)
if username not in user_pool.users:
raise UserNotFoundError(username)
user = user_pool.users[username]
user.update_attributes(attributes)
cognitoidp_backends = {} cognitoidp_backends = {}
for region in boto.cognito.identity.regions(): for region in Session().get_available_regions("cognito-idp"):
cognitoidp_backends[region.name] = CognitoIdpBackend(region.name) cognitoidp_backends[region] = CognitoIdpBackend(region)
for region in Session().get_available_regions(
"cognito-idp", partition_name="aws-us-gov"
):
cognitoidp_backends[region] = CognitoIdpBackend(region)
for region in Session().get_available_regions("cognito-idp", partition_name="aws-cn"):
cognitoidp_backends[region] = CognitoIdpBackend(region)
# Hack to help moto-server process requests on localhost, where the region isn't # Hack to help moto-server process requests on localhost, where the region isn't

View File

@ -8,7 +8,6 @@ from .models import cognitoidp_backends, find_region_by_value
class CognitoIdpResponse(BaseResponse): class CognitoIdpResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
return json.loads(self.body) return json.loads(self.body)
@ -16,10 +15,10 @@ class CognitoIdpResponse(BaseResponse):
# User pool # User pool
def create_user_pool(self): def create_user_pool(self):
name = self.parameters.pop("PoolName") name = self.parameters.pop("PoolName")
user_pool = cognitoidp_backends[self.region].create_user_pool(name, self.parameters) user_pool = cognitoidp_backends[self.region].create_user_pool(
return json.dumps({ name, self.parameters
"UserPool": user_pool.to_json(extended=True) )
}) return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def list_user_pools(self): def list_user_pools(self):
max_results = self._get_param("MaxResults") max_results = self._get_param("MaxResults")
@ -27,9 +26,7 @@ class CognitoIdpResponse(BaseResponse):
user_pools, next_token = cognitoidp_backends[self.region].list_user_pools( user_pools, next_token = cognitoidp_backends[self.region].list_user_pools(
max_results=max_results, next_token=next_token max_results=max_results, next_token=next_token
) )
response = { response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]}
"UserPools": [user_pool.to_json() for user_pool in user_pools],
}
if next_token: if next_token:
response["NextToken"] = str(next_token) response["NextToken"] = str(next_token)
return json.dumps(response) return json.dumps(response)
@ -37,9 +34,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool(self): def describe_user_pool(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id) user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id)
return json.dumps({ return json.dumps({"UserPool": user_pool.to_json(extended=True)})
"UserPool": user_pool.to_json(extended=True)
})
def delete_user_pool(self): def delete_user_pool(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
@ -50,41 +45,63 @@ class CognitoIdpResponse(BaseResponse):
def create_user_pool_domain(self): def create_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].create_user_pool_domain(user_pool_id, domain) custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].create_user_pool_domain(
user_pool_id, domain, custom_domain_config
)
domain_description = user_pool_domain.to_json(extended=False)
if domain_description:
return json.dumps(domain_description)
return "" return ""
def describe_user_pool_domain(self): def describe_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(domain) user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(
domain
)
domain_description = {} domain_description = {}
if user_pool_domain: if user_pool_domain:
domain_description = user_pool_domain.to_json() domain_description = user_pool_domain.to_json()
return json.dumps({ return json.dumps({"DomainDescription": domain_description})
"DomainDescription": domain_description
})
def delete_user_pool_domain(self): def delete_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
cognitoidp_backends[self.region].delete_user_pool_domain(domain) cognitoidp_backends[self.region].delete_user_pool_domain(domain)
return "" return ""
def update_user_pool_domain(self):
domain = self._get_param("Domain")
custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].update_user_pool_domain(
domain, custom_domain_config
)
domain_description = user_pool_domain.to_json(extended=False)
if domain_description:
return json.dumps(domain_description)
return ""
# User pool client # User pool client
def create_user_pool_client(self): def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId") user_pool_id = self.parameters.pop("UserPoolId")
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(user_pool_id, self.parameters) user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(
return json.dumps({ user_pool_id, self.parameters
"UserPoolClient": user_pool_client.to_json(extended=True) )
}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def list_user_pool_clients(self): def list_user_pool_clients(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
max_results = self._get_param("MaxResults") max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken", "0") next_token = self._get_param("NextToken", "0")
user_pool_clients, next_token = cognitoidp_backends[self.region].list_user_pool_clients(user_pool_id, user_pool_clients, next_token = cognitoidp_backends[
max_results=max_results, next_token=next_token) self.region
].list_user_pool_clients(
user_pool_id, max_results=max_results, next_token=next_token
)
response = { response = {
"UserPoolClients": [user_pool_client.to_json() for user_pool_client in user_pool_clients] "UserPoolClients": [
user_pool_client.to_json() for user_pool_client in user_pool_clients
]
} }
if next_token: if next_token:
response["NextToken"] = str(next_token) response["NextToken"] = str(next_token)
@ -93,43 +110,51 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_client(self): def describe_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(user_pool_id, client_id) user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(
return json.dumps({ user_pool_id, client_id
"UserPoolClient": user_pool_client.to_json(extended=True) )
}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def update_user_pool_client(self): def update_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId") user_pool_id = self.parameters.pop("UserPoolId")
client_id = self.parameters.pop("ClientId") client_id = self.parameters.pop("ClientId")
user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(user_pool_id, client_id, self.parameters) user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(
return json.dumps({ user_pool_id, client_id, self.parameters
"UserPoolClient": user_pool_client.to_json(extended=True) )
}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def delete_user_pool_client(self): def delete_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
cognitoidp_backends[self.region].delete_user_pool_client(user_pool_id, client_id) cognitoidp_backends[self.region].delete_user_pool_client(
user_pool_id, client_id
)
return "" return ""
# Identity provider # Identity provider
def create_identity_provider(self): def create_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self.parameters.pop("ProviderName") name = self.parameters.pop("ProviderName")
identity_provider = cognitoidp_backends[self.region].create_identity_provider(user_pool_id, name, self.parameters) identity_provider = cognitoidp_backends[self.region].create_identity_provider(
return json.dumps({ user_pool_id, name, self.parameters
"IdentityProvider": identity_provider.to_json(extended=True) )
}) return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def list_identity_providers(self): def list_identity_providers(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
max_results = self._get_param("MaxResults") max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken", "0") next_token = self._get_param("NextToken", "0")
identity_providers, next_token = cognitoidp_backends[self.region].list_identity_providers( identity_providers, next_token = cognitoidp_backends[
self.region
].list_identity_providers(
user_pool_id, max_results=max_results, next_token=next_token user_pool_id, max_results=max_results, next_token=next_token
) )
response = { response = {
"Providers": [identity_provider.to_json() for identity_provider in identity_providers] "Providers": [
identity_provider.to_json() for identity_provider in identity_providers
]
} }
if next_token: if next_token:
response["NextToken"] = str(next_token) response["NextToken"] = str(next_token)
@ -138,18 +163,22 @@ class CognitoIdpResponse(BaseResponse):
def describe_identity_provider(self): def describe_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName") name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].describe_identity_provider(user_pool_id, name) identity_provider = cognitoidp_backends[self.region].describe_identity_provider(
return json.dumps({ user_pool_id, name
"IdentityProvider": identity_provider.to_json(extended=True) )
}) return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def update_identity_provider(self): def update_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName") name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].update_identity_provider(user_pool_id, name, self.parameters) identity_provider = cognitoidp_backends[self.region].update_identity_provider(
return json.dumps({ user_pool_id, name, self.parameters
"IdentityProvider": identity_provider.to_json(extended=True) )
}) return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def delete_identity_provider(self): def delete_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
@ -166,31 +195,21 @@ class CognitoIdpResponse(BaseResponse):
precedence = self._get_param("Precedence") precedence = self._get_param("Precedence")
group = cognitoidp_backends[self.region].create_group( group = cognitoidp_backends[self.region].create_group(
user_pool_id, user_pool_id, group_name, description, role_arn, precedence
group_name,
description,
role_arn,
precedence,
) )
return json.dumps({ return json.dumps({"Group": group.to_json()})
"Group": group.to_json(),
})
def get_group(self): def get_group(self):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name) group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name)
return json.dumps({ return json.dumps({"Group": group.to_json()})
"Group": group.to_json(),
})
def list_groups(self): def list_groups(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].list_groups(user_pool_id) groups = cognitoidp_backends[self.region].list_groups(user_pool_id)
return json.dumps({ return json.dumps({"Groups": [group.to_json() for group in groups]})
"Groups": [group.to_json() for group in groups],
})
def delete_group(self): def delete_group(self):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
@ -204,9 +223,7 @@ class CognitoIdpResponse(BaseResponse):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_add_user_to_group( cognitoidp_backends[self.region].admin_add_user_to_group(
user_pool_id, user_pool_id, group_name, username
group_name,
username,
) )
return "" return ""
@ -214,18 +231,18 @@ class CognitoIdpResponse(BaseResponse):
def list_users_in_group(self): def list_users_in_group(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
users = cognitoidp_backends[self.region].list_users_in_group(user_pool_id, group_name) users = cognitoidp_backends[self.region].list_users_in_group(
return json.dumps({ user_pool_id, group_name
"Users": [user.to_json(extended=True) for user in users], )
}) return json.dumps({"Users": [user.to_json(extended=True) for user in users]})
def admin_list_groups_for_user(self): def admin_list_groups_for_user(self):
username = self._get_param("Username") username = self._get_param("Username")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].admin_list_groups_for_user(user_pool_id, username) groups = cognitoidp_backends[self.region].admin_list_groups_for_user(
return json.dumps({ user_pool_id, username
"Groups": [group.to_json() for group in groups], )
}) return json.dumps({"Groups": [group.to_json() for group in groups]})
def admin_remove_user_from_group(self): def admin_remove_user_from_group(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
@ -233,9 +250,7 @@ class CognitoIdpResponse(BaseResponse):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_remove_user_from_group( cognitoidp_backends[self.region].admin_remove_user_from_group(
user_pool_id, user_pool_id, group_name, username
group_name,
username,
) )
return "" return ""
@ -249,28 +264,24 @@ class CognitoIdpResponse(BaseResponse):
user_pool_id, user_pool_id,
username, username,
temporary_password, temporary_password,
self._get_param("UserAttributes", []) self._get_param("UserAttributes", []),
) )
return json.dumps({ return json.dumps({"User": user.to_json(extended=True)})
"User": user.to_json(extended=True)
})
def admin_get_user(self): def admin_get_user(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username) user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username)
return json.dumps( return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes"))
user.to_json(extended=True, attributes_key="UserAttributes")
)
def list_users(self): def list_users(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
limit = self._get_param("Limit") limit = self._get_param("Limit")
token = self._get_param("PaginationToken") token = self._get_param("PaginationToken")
users, token = cognitoidp_backends[self.region].list_users(user_pool_id, users, token = cognitoidp_backends[self.region].list_users(
limit=limit, user_pool_id, limit=limit, pagination_token=token
pagination_token=token) )
response = {"Users": [user.to_json(extended=True) for user in users]} response = {"Users": [user.to_json(extended=True) for user in users]}
if token: if token:
response["PaginationToken"] = str(token) response["PaginationToken"] = str(token)
@ -301,10 +312,7 @@ class CognitoIdpResponse(BaseResponse):
auth_parameters = self._get_param("AuthParameters") auth_parameters = self._get_param("AuthParameters")
auth_result = cognitoidp_backends[self.region].admin_initiate_auth( auth_result = cognitoidp_backends[self.region].admin_initiate_auth(
user_pool_id, user_pool_id, client_id, auth_flow, auth_parameters
client_id,
auth_flow,
auth_parameters,
) )
return json.dumps(auth_result) return json.dumps(auth_result)
@ -315,21 +323,15 @@ class CognitoIdpResponse(BaseResponse):
challenge_name = self._get_param("ChallengeName") challenge_name = self._get_param("ChallengeName")
challenge_responses = self._get_param("ChallengeResponses") challenge_responses = self._get_param("ChallengeResponses")
auth_result = cognitoidp_backends[self.region].respond_to_auth_challenge( auth_result = cognitoidp_backends[self.region].respond_to_auth_challenge(
session, session, client_id, challenge_name, challenge_responses
client_id,
challenge_name,
challenge_responses,
) )
return json.dumps(auth_result) return json.dumps(auth_result)
def forgot_password(self): def forgot_password(self):
return json.dumps({ return json.dumps(
"CodeDeliveryDetails": { {"CodeDeliveryDetails": {"DeliveryMedium": "EMAIL", "Destination": "..."}}
"DeliveryMedium": "EMAIL", )
"Destination": "...",
}
})
# This endpoint receives no authorization header, so if moto-server is listening # This endpoint receives no authorization header, so if moto-server is listening
# on localhost (doesn't get a region in the host header), it doesn't know what # on localhost (doesn't get a region in the host header), it doesn't know what
@ -340,7 +342,9 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username") username = self._get_param("Username")
password = self._get_param("Password") password = self._get_param("Password")
region = find_region_by_value("client_id", client_id) region = find_region_by_value("client_id", client_id)
cognitoidp_backends[region].confirm_forgot_password(client_id, username, password) cognitoidp_backends[region].confirm_forgot_password(
client_id, username, password
)
return "" return ""
# Ditto the comment on confirm_forgot_password. # Ditto the comment on confirm_forgot_password.
@ -349,14 +353,26 @@ class CognitoIdpResponse(BaseResponse):
previous_password = self._get_param("PreviousPassword") previous_password = self._get_param("PreviousPassword")
proposed_password = self._get_param("ProposedPassword") proposed_password = self._get_param("ProposedPassword")
region = find_region_by_value("access_token", access_token) region = find_region_by_value("access_token", access_token)
cognitoidp_backends[region].change_password(access_token, previous_password, proposed_password) cognitoidp_backends[region].change_password(
access_token, previous_password, proposed_password
)
return ""
def admin_update_user_attributes(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
attributes = self._get_param("UserAttributes")
cognitoidp_backends[self.region].admin_update_user_attributes(
user_pool_id, username, attributes
)
return "" return ""
class CognitoIdpJsonWebKeyResponse(BaseResponse): class CognitoIdpJsonWebKeyResponse(BaseResponse):
def __init__(self): def __init__(self):
with open(os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")) as f: with open(
os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")
) as f:
self.json_web_key = f.read() self.json_web_key = f.read()
def serve_json_web_key(self, request, full_url, headers): def serve_json_web_key(self, request, full_url, headers):

View File

@ -1,11 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import CognitoIdpResponse, CognitoIdpJsonWebKeyResponse from .responses import CognitoIdpResponse, CognitoIdpJsonWebKeyResponse
url_bases = [ url_bases = ["https?://cognito-idp.(.+).amazonaws.com"]
"https?://cognito-idp.(.+).amazonaws.com",
]
url_paths = { url_paths = {
'{0}/$': CognitoIdpResponse.dispatch, "{0}/$": CognitoIdpResponse.dispatch,
'{0}/<user_pool_id>/.well-known/jwks.json$': CognitoIdpJsonWebKeyResponse().serve_json_web_key, "{0}/<user_pool_id>/.well-known/jwks.json$": CognitoIdpJsonWebKeyResponse().serve_json_web_key,
} }

View File

@ -1,5 +1,10 @@
try: try:
from collections import OrderedDict # flake8: noqa from collections import OrderedDict # noqa
except ImportError: except ImportError:
# python 2.6 or earlier, use backport # python 2.6 or earlier, use backport
from ordereddict import OrderedDict # flake8: noqa from ordereddict import OrderedDict # noqa
try:
import collections.abc as collections_abc # noqa
except ImportError:
import collections as collections_abc # noqa

View File

@ -6,8 +6,12 @@ class NameTooLongException(JsonRESTError):
code = 400 code = 400
def __init__(self, name, location): def __init__(self, name, location):
message = '1 validation error detected: Value \'{name}\' at \'{location}\' failed to satisfy' \ message = (
' constraint: Member must have length less than or equal to 256'.format(name=name, location=location) "1 validation error detected: Value '{name}' at '{location}' failed to satisfy"
" constraint: Member must have length less than or equal to 256".format(
name=name, location=location
)
)
super(NameTooLongException, self).__init__("ValidationException", message) super(NameTooLongException, self).__init__("ValidationException", message)
@ -15,135 +19,350 @@ class InvalidConfigurationRecorderNameException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'The configuration recorder name \'{name}\' is not valid, blank string.'.format(name=name) message = "The configuration recorder name '{name}' is not valid, blank string.".format(
super(InvalidConfigurationRecorderNameException, self).__init__("InvalidConfigurationRecorderNameException", name=name
message) )
super(InvalidConfigurationRecorderNameException, self).__init__(
"InvalidConfigurationRecorderNameException", message
)
class MaxNumberOfConfigurationRecordersExceededException(JsonRESTError): class MaxNumberOfConfigurationRecordersExceededException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Failed to put configuration recorder \'{name}\' because the maximum number of ' \ message = (
'configuration recorders: 1 is reached.'.format(name=name) "Failed to put configuration recorder '{name}' because the maximum number of "
"configuration recorders: 1 is reached.".format(name=name)
)
super(MaxNumberOfConfigurationRecordersExceededException, self).__init__( super(MaxNumberOfConfigurationRecordersExceededException, self).__init__(
"MaxNumberOfConfigurationRecordersExceededException", message) "MaxNumberOfConfigurationRecordersExceededException", message
)
class InvalidRecordingGroupException(JsonRESTError): class InvalidRecordingGroupException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'The recording group provided is not valid' message = "The recording group provided is not valid"
super(InvalidRecordingGroupException, self).__init__("InvalidRecordingGroupException", message) super(InvalidRecordingGroupException, self).__init__(
"InvalidRecordingGroupException", message
)
class InvalidResourceTypeException(JsonRESTError): class InvalidResourceTypeException(JsonRESTError):
code = 400 code = 400
def __init__(self, bad_list, good_list): def __init__(self, bad_list, good_list):
message = '{num} validation error detected: Value \'{bad_list}\' at ' \ message = (
'\'configurationRecorder.recordingGroup.resourceTypes\' failed to satisfy constraint: ' \ "{num} validation error detected: Value '{bad_list}' at "
'Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]'.format( "'configurationRecorder.recordingGroup.resourceTypes' failed to satisfy constraint: "
num=len(bad_list), bad_list=bad_list, good_list=good_list) "Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]".format(
num=len(bad_list), bad_list=bad_list, good_list=good_list
)
)
# For PY2: # For PY2:
message = str(message) message = str(message)
super(InvalidResourceTypeException, self).__init__("ValidationException", message) super(InvalidResourceTypeException, self).__init__(
"ValidationException", message
)
class NoSuchConfigurationAggregatorException(JsonRESTError):
code = 400
def __init__(self, number=1):
if number == 1:
message = "The configuration aggregator does not exist. Check the configuration aggregator name and try again."
else:
message = (
"At least one of the configuration aggregators does not exist. Check the configuration aggregator"
" names and try again."
)
super(NoSuchConfigurationAggregatorException, self).__init__(
"NoSuchConfigurationAggregatorException", message
)
class NoSuchConfigurationRecorderException(JsonRESTError): class NoSuchConfigurationRecorderException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Cannot find configuration recorder with the specified name \'{name}\'.'.format(name=name) message = "Cannot find configuration recorder with the specified name '{name}'.".format(
super(NoSuchConfigurationRecorderException, self).__init__("NoSuchConfigurationRecorderException", message) name=name
)
super(NoSuchConfigurationRecorderException, self).__init__(
"NoSuchConfigurationRecorderException", message
)
class InvalidDeliveryChannelNameException(JsonRESTError): class InvalidDeliveryChannelNameException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'The delivery channel name \'{name}\' is not valid, blank string.'.format(name=name) message = "The delivery channel name '{name}' is not valid, blank string.".format(
super(InvalidDeliveryChannelNameException, self).__init__("InvalidDeliveryChannelNameException", name=name
message) )
super(InvalidDeliveryChannelNameException, self).__init__(
"InvalidDeliveryChannelNameException", message
)
class NoSuchBucketException(JsonRESTError): class NoSuchBucketException(JsonRESTError):
"""We are *only* validating that there is value that is not '' here.""" """We are *only* validating that there is value that is not '' here."""
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'Cannot find a S3 bucket with an empty bucket name.' message = "Cannot find a S3 bucket with an empty bucket name."
super(NoSuchBucketException, self).__init__("NoSuchBucketException", message) super(NoSuchBucketException, self).__init__("NoSuchBucketException", message)
class InvalidNextTokenException(JsonRESTError):
code = 400
def __init__(self):
message = "The nextToken provided is invalid"
super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", message
)
class InvalidS3KeyPrefixException(JsonRESTError): class InvalidS3KeyPrefixException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'The s3 key prefix \'\' is not valid, empty s3 key prefix.' message = "The s3 key prefix '' is not valid, empty s3 key prefix."
super(InvalidS3KeyPrefixException, self).__init__("InvalidS3KeyPrefixException", message) super(InvalidS3KeyPrefixException, self).__init__(
"InvalidS3KeyPrefixException", message
)
class InvalidSNSTopicARNException(JsonRESTError): class InvalidSNSTopicARNException(JsonRESTError):
"""We are *only* validating that there is value that is not '' here.""" """We are *only* validating that there is value that is not '' here."""
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'The sns topic arn \'\' is not valid.' message = "The sns topic arn '' is not valid."
super(InvalidSNSTopicARNException, self).__init__("InvalidSNSTopicARNException", message) super(InvalidSNSTopicARNException, self).__init__(
"InvalidSNSTopicARNException", message
)
class InvalidDeliveryFrequency(JsonRESTError): class InvalidDeliveryFrequency(JsonRESTError):
code = 400 code = 400
def __init__(self, value, good_list): def __init__(self, value, good_list):
message = '1 validation error detected: Value \'{value}\' at ' \ message = (
'\'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency\' failed to satisfy ' \ "1 validation error detected: Value '{value}' at "
'constraint: Member must satisfy enum value set: {good_list}'.format(value=value, good_list=good_list) "'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency' failed to satisfy "
super(InvalidDeliveryFrequency, self).__init__("InvalidDeliveryFrequency", message) "constraint: Member must satisfy enum value set: {good_list}".format(
value=value, good_list=good_list
)
)
super(InvalidDeliveryFrequency, self).__init__(
"InvalidDeliveryFrequency", message
)
class MaxNumberOfDeliveryChannelsExceededException(JsonRESTError): class MaxNumberOfDeliveryChannelsExceededException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Failed to put delivery channel \'{name}\' because the maximum number of ' \ message = (
'delivery channels: 1 is reached.'.format(name=name) "Failed to put delivery channel '{name}' because the maximum number of "
"delivery channels: 1 is reached.".format(name=name)
)
super(MaxNumberOfDeliveryChannelsExceededException, self).__init__( super(MaxNumberOfDeliveryChannelsExceededException, self).__init__(
"MaxNumberOfDeliveryChannelsExceededException", message) "MaxNumberOfDeliveryChannelsExceededException", message
)
class NoSuchDeliveryChannelException(JsonRESTError): class NoSuchDeliveryChannelException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Cannot find delivery channel with specified name \'{name}\'.'.format(name=name) message = "Cannot find delivery channel with specified name '{name}'.".format(
super(NoSuchDeliveryChannelException, self).__init__("NoSuchDeliveryChannelException", message) name=name
)
super(NoSuchDeliveryChannelException, self).__init__(
"NoSuchDeliveryChannelException", message
)
class NoAvailableConfigurationRecorderException(JsonRESTError): class NoAvailableConfigurationRecorderException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'Configuration recorder is not available to put delivery channel.' message = "Configuration recorder is not available to put delivery channel."
super(NoAvailableConfigurationRecorderException, self).__init__("NoAvailableConfigurationRecorderException", super(NoAvailableConfigurationRecorderException, self).__init__(
message) "NoAvailableConfigurationRecorderException", message
)
class NoAvailableDeliveryChannelException(JsonRESTError): class NoAvailableDeliveryChannelException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'Delivery channel is not available to start configuration recorder.' message = "Delivery channel is not available to start configuration recorder."
super(NoAvailableDeliveryChannelException, self).__init__("NoAvailableDeliveryChannelException", message) super(NoAvailableDeliveryChannelException, self).__init__(
"NoAvailableDeliveryChannelException", message
)
class LastDeliveryChannelDeleteFailedException(JsonRESTError): class LastDeliveryChannelDeleteFailedException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \ message = (
'because there is a running configuration recorder.'.format(name=name) "Failed to delete last specified delivery channel with name '{name}', because there, "
super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message) "because there is a running configuration recorder.".format(name=name)
)
super(LastDeliveryChannelDeleteFailedException, self).__init__(
"LastDeliveryChannelDeleteFailedException", message
)
class TooManyAccountSources(JsonRESTError):
code = 400
def __init__(self, length):
locations = ["com.amazonaws.xyz"] * length
message = (
"Value '[{locations}]' at 'accountAggregationSources' failed to satisfy constraint: "
"Member must have length less than or equal to 1".format(
locations=", ".join(locations)
)
)
super(TooManyAccountSources, self).__init__("ValidationException", message)
class DuplicateTags(JsonRESTError):
code = 400
def __init__(self):
super(DuplicateTags, self).__init__(
"InvalidInput",
"Duplicate tag keys found. Please note that Tag keys are case insensitive.",
)
class TagKeyTooBig(JsonRESTError):
code = 400
def __init__(self, tag, param="tags.X.member.key"):
super(TagKeyTooBig, self).__init__(
"ValidationException",
"1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 128".format(
tag, param
),
)
class TagValueTooBig(JsonRESTError):
code = 400
def __init__(self, tag):
super(TagValueTooBig, self).__init__(
"ValidationException",
"1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy "
"constraint: Member must have length less than or equal to 256".format(tag),
)
class InvalidParameterValueException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidParameterValueException, self).__init__(
"InvalidParameterValueException", message
)
class InvalidTagCharacters(JsonRESTError):
code = 400
def __init__(self, tag, param="tags.X.member.key"):
message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(
tag, param
)
message += "constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+"
super(InvalidTagCharacters, self).__init__("ValidationException", message)
class TooManyTags(JsonRESTError):
code = 400
def __init__(self, tags, param="tags"):
super(TooManyTags, self).__init__(
"ValidationException",
"1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 50.".format(
tags, param
),
)
class InvalidResourceParameters(JsonRESTError):
code = 400
def __init__(self):
super(InvalidResourceParameters, self).__init__(
"ValidationException",
"Both Resource ID and Resource Name " "cannot be specified in the request",
)
class InvalidLimit(JsonRESTError):
code = 400
def __init__(self, value):
super(InvalidLimit, self).__init__(
"ValidationException",
"Value '{value}' at 'limit' failed to satisify constraint: Member"
" must have value less than or equal to 100".format(value=value),
)
class TooManyResourceIds(JsonRESTError):
code = 400
def __init__(self):
super(TooManyResourceIds, self).__init__(
"ValidationException",
"The specified list had more than 20 resource ID's. "
"It must have '20' or less items",
)
class ResourceNotDiscoveredException(JsonRESTError):
code = 400
def __init__(self, type, resource):
super(ResourceNotDiscoveredException, self).__init__(
"ResourceNotDiscoveredException",
"Resource {resource} of resourceType:{type} is unknown or has not been "
"discovered".format(resource=resource, type=type),
)
class TooManyResourceKeys(JsonRESTError):
code = 400
def __init__(self, bad_list):
message = (
"1 validation error detected: Value '{bad_list}' at "
"'resourceKeys' failed to satisfy constraint: "
"Member must have length less than or equal to 100".format(
bad_list=bad_list
)
)
# For PY2:
message = str(message)
super(TooManyResourceKeys, self).__init__("ValidationException", message)

File diff suppressed because it is too large Load Diff

View File

@ -4,50 +4,150 @@ from .models import config_backends
class ConfigResponse(BaseResponse): class ConfigResponse(BaseResponse):
@property @property
def config_backend(self): def config_backend(self):
return config_backends[self.region] return config_backends[self.region]
def put_configuration_recorder(self): def put_configuration_recorder(self):
self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder')) self.config_backend.put_configuration_recorder(
self._get_param("ConfigurationRecorder")
)
return ""
def put_configuration_aggregator(self):
aggregator = self.config_backend.put_configuration_aggregator(
json.loads(self.body), self.region
)
schema = {"ConfigurationAggregator": aggregator}
return json.dumps(schema)
def describe_configuration_aggregators(self):
aggregators = self.config_backend.describe_configuration_aggregators(
self._get_param("ConfigurationAggregatorNames"),
self._get_param("NextToken"),
self._get_param("Limit"),
)
return json.dumps(aggregators)
def delete_configuration_aggregator(self):
self.config_backend.delete_configuration_aggregator(
self._get_param("ConfigurationAggregatorName")
)
return ""
def put_aggregation_authorization(self):
agg_auth = self.config_backend.put_aggregation_authorization(
self.region,
self._get_param("AuthorizedAccountId"),
self._get_param("AuthorizedAwsRegion"),
self._get_param("Tags"),
)
schema = {"AggregationAuthorization": agg_auth}
return json.dumps(schema)
def describe_aggregation_authorizations(self):
authorizations = self.config_backend.describe_aggregation_authorizations(
self._get_param("NextToken"), self._get_param("Limit")
)
return json.dumps(authorizations)
def delete_aggregation_authorization(self):
self.config_backend.delete_aggregation_authorization(
self._get_param("AuthorizedAccountId"),
self._get_param("AuthorizedAwsRegion"),
)
return "" return ""
def describe_configuration_recorders(self): def describe_configuration_recorders(self):
recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames')) recorders = self.config_backend.describe_configuration_recorders(
schema = {'ConfigurationRecorders': recorders} self._get_param("ConfigurationRecorderNames")
)
schema = {"ConfigurationRecorders": recorders}
return json.dumps(schema) return json.dumps(schema)
def describe_configuration_recorder_status(self): def describe_configuration_recorder_status(self):
recorder_statuses = self.config_backend.describe_configuration_recorder_status( recorder_statuses = self.config_backend.describe_configuration_recorder_status(
self._get_param('ConfigurationRecorderNames')) self._get_param("ConfigurationRecorderNames")
schema = {'ConfigurationRecordersStatus': recorder_statuses} )
schema = {"ConfigurationRecordersStatus": recorder_statuses}
return json.dumps(schema) return json.dumps(schema)
def put_delivery_channel(self): def put_delivery_channel(self):
self.config_backend.put_delivery_channel(self._get_param('DeliveryChannel')) self.config_backend.put_delivery_channel(self._get_param("DeliveryChannel"))
return "" return ""
def describe_delivery_channels(self): def describe_delivery_channels(self):
delivery_channels = self.config_backend.describe_delivery_channels(self._get_param('DeliveryChannelNames')) delivery_channels = self.config_backend.describe_delivery_channels(
schema = {'DeliveryChannels': delivery_channels} self._get_param("DeliveryChannelNames")
)
schema = {"DeliveryChannels": delivery_channels}
return json.dumps(schema) return json.dumps(schema)
def describe_delivery_channel_status(self): def describe_delivery_channel_status(self):
raise NotImplementedError() raise NotImplementedError()
def delete_delivery_channel(self): def delete_delivery_channel(self):
self.config_backend.delete_delivery_channel(self._get_param('DeliveryChannelName')) self.config_backend.delete_delivery_channel(
self._get_param("DeliveryChannelName")
)
return "" return ""
def delete_configuration_recorder(self): def delete_configuration_recorder(self):
self.config_backend.delete_configuration_recorder(self._get_param('ConfigurationRecorderName')) self.config_backend.delete_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return "" return ""
def start_configuration_recorder(self): def start_configuration_recorder(self):
self.config_backend.start_configuration_recorder(self._get_param('ConfigurationRecorderName')) self.config_backend.start_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return "" return ""
def stop_configuration_recorder(self): def stop_configuration_recorder(self):
self.config_backend.stop_configuration_recorder(self._get_param('ConfigurationRecorderName')) self.config_backend.stop_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return "" return ""
def list_discovered_resources(self):
schema = self.config_backend.list_discovered_resources(
self._get_param("resourceType"),
self.region,
self._get_param("resourceIds"),
self._get_param("resourceName"),
self._get_param("limit"),
self._get_param("nextToken"),
)
return json.dumps(schema)
def list_aggregate_discovered_resources(self):
schema = self.config_backend.list_aggregate_discovered_resources(
self._get_param("ConfigurationAggregatorName"),
self._get_param("ResourceType"),
self._get_param("Filters"),
self._get_param("Limit"),
self._get_param("NextToken"),
)
return json.dumps(schema)
def get_resource_config_history(self):
schema = self.config_backend.get_resource_config_history(
self._get_param("resourceType"), self._get_param("resourceId"), self.region
)
return json.dumps(schema)
def batch_get_resource_config(self):
schema = self.config_backend.batch_get_resource_config(
self._get_param("resourceKeys"), self.region
)
return json.dumps(schema)
def batch_get_aggregate_resource_config(self):
schema = self.config_backend.batch_get_aggregate_resource_config(
self._get_param("ConfigurationAggregatorName"),
self._get_param("ResourceIdentifiers"),
)
return json.dumps(schema)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import ConfigResponse from .responses import ConfigResponse
url_bases = [ url_bases = ["https?://config.(.+).amazonaws.com"]
"https?://config.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": ConfigResponse.dispatch}
'{0}/$': ConfigResponse.dispatch,
}

View File

@ -1,4 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import BaseModel, BaseBackend, moto_api_backend # flake8: noqa
from .models import BaseModel, BaseBackend, moto_api_backend, ACCOUNT_ID # noqa
from .responses import ActionAuthenticatorMixin
moto_api_backends = {"global": moto_api_backend} moto_api_backends = {"global": moto_api_backend}
set_initial_no_auth_action_count = (
ActionAuthenticatorMixin.set_initial_no_auth_action_count
)

404
moto/core/access_control.py Normal file
View File

@ -0,0 +1,404 @@
"""
This implementation is NOT complete, there are many things to improve.
The following is a list of the most important missing features and inaccuracies.
TODO add support for more principals, apart from IAM users and assumed IAM roles
TODO add support for the Resource and Condition parts of IAM policies
TODO add support and create tests for all services in moto (for example, API Gateway is probably not supported currently)
TODO implement service specific error messages (currently, EC2 and S3 are supported separately, everything else defaults to the errors IAM returns)
TODO include information about the action's resource in error messages (once the Resource element in IAM policies is supported)
TODO check all other actions that are performed by the action called by the user (for example, autoscaling:CreateAutoScalingGroup requires permission for iam:CreateServiceLinkedRole too - see https://docs.aws.amazon.com/autoscaling/ec2/userguide/control-access-using-iam.html)
TODO add support for resource-based policies
"""
import json
import logging
import re
from abc import abstractmethod, ABCMeta
from enum import Enum
import six
from botocore.auth import SigV4Auth, S3SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
from six import string_types
from moto.core import ACCOUNT_ID
from moto.iam.models import Policy
from moto.iam import iam_backend
from moto.core.exceptions import (
SignatureDoesNotMatchError,
AccessDeniedError,
InvalidClientTokenIdError,
AuthFailureError,
)
from moto.s3.exceptions import (
BucketAccessDeniedError,
S3AccessDeniedError,
BucketInvalidTokenError,
S3InvalidTokenError,
S3InvalidAccessKeyIdError,
BucketInvalidAccessKeyIdError,
BucketSignatureDoesNotMatchError,
S3SignatureDoesNotMatchError,
)
from moto.sts import sts_backend
log = logging.getLogger(__name__)
def create_access_key(access_key_id, headers):
if access_key_id.startswith("AKIA") or "X-Amz-Security-Token" not in headers:
return IAMUserAccessKey(access_key_id, headers)
else:
return AssumedRoleAccessKey(access_key_id, headers)
class IAMUserAccessKey(object):
def __init__(self, access_key_id, headers):
iam_users = iam_backend.list_users("/", None, None)
for iam_user in iam_users:
for access_key in iam_user.access_keys:
if access_key.access_key_id == access_key_id:
self._owner_user_name = iam_user.name
self._access_key_id = access_key_id
self._secret_access_key = access_key.secret_access_key
if "X-Amz-Security-Token" in headers:
raise CreateAccessKeyFailure(reason="InvalidToken")
return
raise CreateAccessKeyFailure(reason="InvalidId")
@property
def arn(self):
return "arn:aws:iam::{account_id}:user/{iam_user_name}".format(
account_id=ACCOUNT_ID, iam_user_name=self._owner_user_name
)
def create_credentials(self):
return Credentials(self._access_key_id, self._secret_access_key)
def collect_policies(self):
user_policies = []
inline_policy_names = iam_backend.list_user_policies(self._owner_user_name)
for inline_policy_name in inline_policy_names:
inline_policy = iam_backend.get_user_policy(
self._owner_user_name, inline_policy_name
)
user_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_user_policies(
self._owner_user_name
)
user_policies += attached_policies
user_groups = iam_backend.get_groups_for_user(self._owner_user_name)
for user_group in user_groups:
inline_group_policy_names = iam_backend.list_group_policies(user_group.name)
for inline_group_policy_name in inline_group_policy_names:
inline_user_group_policy = iam_backend.get_group_policy(
user_group.name, inline_group_policy_name
)
user_policies.append(inline_user_group_policy)
attached_group_policies, _ = iam_backend.list_attached_group_policies(
user_group.name
)
user_policies += attached_group_policies
return user_policies
class AssumedRoleAccessKey(object):
def __init__(self, access_key_id, headers):
for assumed_role in sts_backend.assumed_roles:
if assumed_role.access_key_id == access_key_id:
self._access_key_id = access_key_id
self._secret_access_key = assumed_role.secret_access_key
self._session_token = assumed_role.session_token
self._owner_role_name = assumed_role.role_arn.split("/")[-1]
self._session_name = assumed_role.session_name
if headers["X-Amz-Security-Token"] != self._session_token:
raise CreateAccessKeyFailure(reason="InvalidToken")
return
raise CreateAccessKeyFailure(reason="InvalidId")
@property
def arn(self):
return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format(
account_id=ACCOUNT_ID,
role_name=self._owner_role_name,
session_name=self._session_name,
)
def create_credentials(self):
return Credentials(
self._access_key_id, self._secret_access_key, self._session_token
)
def collect_policies(self):
role_policies = []
inline_policy_names = iam_backend.list_role_policies(self._owner_role_name)
for inline_policy_name in inline_policy_names:
_, inline_policy = iam_backend.get_role_policy(
self._owner_role_name, inline_policy_name
)
role_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_role_policies(
self._owner_role_name
)
role_policies += attached_policies
return role_policies
class CreateAccessKeyFailure(Exception):
def __init__(self, reason, *args):
super(CreateAccessKeyFailure, self).__init__(*args)
self.reason = reason
@six.add_metaclass(ABCMeta)
class IAMRequestBase(object):
def __init__(self, method, path, data, headers):
log.debug(
"Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format(
class_name=self.__class__.__name__,
method=method,
path=path,
data=data,
headers=headers,
)
)
self._method = method
self._path = path
self._data = data
self._headers = headers
credential_scope = self._get_string_between(
"Credential=", ",", self._headers["Authorization"]
)
credential_data = credential_scope.split("/")
self._region = credential_data[2]
self._service = credential_data[3]
self._action = (
self._service
+ ":"
+ (
self._data["Action"][0]
if isinstance(self._data["Action"], list)
else self._data["Action"]
)
)
try:
self._access_key = create_access_key(
access_key_id=credential_data[0], headers=headers
)
except CreateAccessKeyFailure as e:
self._raise_invalid_access_key(e.reason)
def check_signature(self):
original_signature = self._get_string_between(
"Signature=", ",", self._headers["Authorization"]
)
calculated_signature = self._calculate_signature()
if original_signature != calculated_signature:
self._raise_signature_does_not_match()
def check_action_permitted(self):
if (
self._action == "sts:GetCallerIdentity"
): # always allowed, even if there's an explicit Deny for it
return True
policies = self._access_key.collect_policies()
permitted = False
for policy in policies:
iam_policy = IAMPolicy(policy)
permission_result = iam_policy.is_action_permitted(self._action)
if permission_result == PermissionResult.DENIED:
self._raise_access_denied()
elif permission_result == PermissionResult.PERMITTED:
permitted = True
if not permitted:
self._raise_access_denied()
@abstractmethod
def _raise_signature_does_not_match(self):
raise NotImplementedError()
@abstractmethod
def _raise_access_denied(self):
raise NotImplementedError()
@abstractmethod
def _raise_invalid_access_key(self, reason):
raise NotImplementedError()
@abstractmethod
def _create_auth(self, credentials):
raise NotImplementedError()
@staticmethod
def _create_headers_for_aws_request(signed_headers, original_headers):
headers = {}
for key, value in original_headers.items():
if key.lower() in signed_headers:
headers[key] = value
return headers
def _create_aws_request(self):
signed_headers = self._get_string_between(
"SignedHeaders=", ",", self._headers["Authorization"]
).split(";")
headers = self._create_headers_for_aws_request(signed_headers, self._headers)
request = AWSRequest(
method=self._method, url=self._path, data=self._data, headers=headers
)
request.context["timestamp"] = headers["X-Amz-Date"]
return request
def _calculate_signature(self):
credentials = self._access_key.create_credentials()
auth = self._create_auth(credentials)
request = self._create_aws_request()
canonical_request = auth.canonical_request(request)
string_to_sign = auth.string_to_sign(request, canonical_request)
return auth.signature(string_to_sign, request)
@staticmethod
def _get_string_between(first_separator, second_separator, string):
return string.partition(first_separator)[2].partition(second_separator)[0]
class IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self):
if self._service == "ec2":
raise AuthFailureError()
else:
raise SignatureDoesNotMatchError()
def _raise_invalid_access_key(self, _):
if self._service == "ec2":
raise AuthFailureError()
else:
raise InvalidClientTokenIdError()
def _create_auth(self, credentials):
return SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self):
raise AccessDeniedError(user_arn=self._access_key.arn, action=self._action)
class S3IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self):
if "BucketName" in self._data:
raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"])
else:
raise S3SignatureDoesNotMatchError()
def _raise_invalid_access_key(self, reason):
if reason == "InvalidToken":
if "BucketName" in self._data:
raise BucketInvalidTokenError(bucket=self._data["BucketName"])
else:
raise S3InvalidTokenError()
else:
if "BucketName" in self._data:
raise BucketInvalidAccessKeyIdError(bucket=self._data["BucketName"])
else:
raise S3InvalidAccessKeyIdError()
def _create_auth(self, credentials):
return S3SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self):
if "BucketName" in self._data:
raise BucketAccessDeniedError(bucket=self._data["BucketName"])
else:
raise S3AccessDeniedError()
class IAMPolicy(object):
def __init__(self, policy):
if isinstance(policy, Policy):
default_version = next(
policy_version
for policy_version in policy.versions
if policy_version.is_default
)
policy_document = default_version.document
elif isinstance(policy, string_types):
policy_document = policy
else:
policy_document = policy["policy_document"]
self._policy_json = json.loads(policy_document)
def is_action_permitted(self, action):
permitted = False
if isinstance(self._policy_json["Statement"], list):
for policy_statement in self._policy_json["Statement"]:
iam_policy_statement = IAMPolicyStatement(policy_statement)
permission_result = iam_policy_statement.is_action_permitted(action)
if permission_result == PermissionResult.DENIED:
return permission_result
elif permission_result == PermissionResult.PERMITTED:
permitted = True
else: # dict
iam_policy_statement = IAMPolicyStatement(self._policy_json["Statement"])
return iam_policy_statement.is_action_permitted(action)
if permitted:
return PermissionResult.PERMITTED
else:
return PermissionResult.NEUTRAL
class IAMPolicyStatement(object):
def __init__(self, statement):
self._statement = statement
def is_action_permitted(self, action):
is_action_concerned = False
if "NotAction" in self._statement:
if not self._check_element_matches("NotAction", action):
is_action_concerned = True
else: # Action is present
if self._check_element_matches("Action", action):
is_action_concerned = True
if is_action_concerned:
if self._statement["Effect"] == "Allow":
return PermissionResult.PERMITTED
else: # Deny
return PermissionResult.DENIED
else:
return PermissionResult.NEUTRAL
def _check_element_matches(self, statement_element, value):
if isinstance(self._statement[statement_element], list):
for statement_element_value in self._statement[statement_element]:
if self._match(statement_element_value, value):
return True
return False
else: # string
return self._match(self._statement[statement_element], value)
@staticmethod
def _match(pattern, string):
pattern = pattern.replace("*", ".*")
pattern = "^{pattern}$".format(pattern=pattern)
return re.match(pattern, string)
class PermissionResult(Enum):
PERMITTED = 1
DENIED = 2
NEUTRAL = 3

View File

@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment from jinja2 import DictLoader, Environment
SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?> SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Error> <Error>
<Code>{{error_type}}</Code> <Code>{{error_type}}</Code>
<Message>{{message}}</Message> <Message>{{message}}</Message>
@ -13,8 +13,8 @@ SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error> </Error>
""" """
ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?> ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Response> <ErrorResponse>
<Errors> <Errors>
<Error> <Error>
<Code>{{error_type}}</Code> <Code>{{error_type}}</Code>
@ -23,10 +23,10 @@ ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error> </Error>
</Errors> </Errors>
<RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID> <RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID>
</Response> </ErrorResponse>
""" """
ERROR_JSON_RESPONSE = u"""{ ERROR_JSON_RESPONSE = """{
"message": "{{message}}", "message": "{{message}}",
"__type": "{{error_type}}" "__type": "{{error_type}}"
} }
@ -37,18 +37,19 @@ class RESTError(HTTPException):
code = 400 code = 400
templates = { templates = {
'single_error': SINGLE_ERROR_RESPONSE, "single_error": SINGLE_ERROR_RESPONSE,
'error': ERROR_RESPONSE, "error": ERROR_RESPONSE,
'error_json': ERROR_JSON_RESPONSE, "error_json": ERROR_JSON_RESPONSE,
} }
def __init__(self, error_type, message, template='error', **kwargs): def __init__(self, error_type, message, template="error", **kwargs):
super(RESTError, self).__init__() super(RESTError, self).__init__()
env = Environment(loader=DictLoader(self.templates)) env = Environment(loader=DictLoader(self.templates))
self.error_type = error_type self.error_type = error_type
self.message = message self.message = message
self.description = env.get_template(template).render( self.description = env.get_template(template).render(
error_type=error_type, message=message, **kwargs) error_type=error_type, message=message, **kwargs
)
class DryRunClientError(RESTError): class DryRunClientError(RESTError):
@ -56,12 +57,64 @@ class DryRunClientError(RESTError):
class JsonRESTError(RESTError): class JsonRESTError(RESTError):
def __init__(self, error_type, message, template='error_json', **kwargs): def __init__(self, error_type, message, template="error_json", **kwargs):
super(JsonRESTError, self).__init__( super(JsonRESTError, self).__init__(error_type, message, template, **kwargs)
error_type, message, template, **kwargs)
def get_headers(self, *args, **kwargs): def get_headers(self, *args, **kwargs):
return [('Content-Type', 'application/json')] return [("Content-Type", "application/json")]
def get_body(self, *args, **kwargs): def get_body(self, *args, **kwargs):
return self.description return self.description
class SignatureDoesNotMatchError(RESTError):
code = 403
def __init__(self):
super(SignatureDoesNotMatchError, self).__init__(
"SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.",
)
class InvalidClientTokenIdError(RESTError):
code = 403
def __init__(self):
super(InvalidClientTokenIdError, self).__init__(
"InvalidClientTokenId",
"The security token included in the request is invalid.",
)
class AccessDeniedError(RESTError):
code = 403
def __init__(self, user_arn, action):
super(AccessDeniedError, self).__init__(
"AccessDenied",
"User: {user_arn} is not authorized to perform: {operation}".format(
user_arn=user_arn, operation=action
),
)
class AuthFailureError(RESTError):
code = 401
def __init__(self):
super(AuthFailureError, self).__init__(
"AuthFailure",
"AWS was not able to validate the provided access credentials",
)
class InvalidNextTokenException(JsonRESTError):
"""For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core."""
code = 400
def __init__(self):
super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", "The nextToken provided is invalid"
)

View File

@ -12,6 +12,7 @@ from collections import defaultdict
from botocore.handlers import BUILTIN_HANDLERS from botocore.handlers import BUILTIN_HANDLERS
from botocore.awsrequest import AWSResponse from botocore.awsrequest import AWSResponse
import mock
from moto import settings from moto import settings
import responses import responses
from moto.packages.httpretty import HTTPretty from moto.packages.httpretty import HTTPretty
@ -22,9 +23,7 @@ from .utils import (
) )
# "Mock" the AWS credentials as they can't be mocked in Botocore currently ACCOUNT_ID = os.environ.get("MOTO_ACCOUNT_ID", "123456789012")
os.environ.setdefault("AWS_ACCESS_KEY_ID", "foobar_key")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "foobar_secret")
class BaseMockAWS(object): class BaseMockAWS(object):
@ -35,13 +34,22 @@ class BaseMockAWS(object):
self.backends_for_urls = {} self.backends_for_urls = {}
from moto.backends import BACKENDS from moto.backends import BACKENDS
default_backends = { default_backends = {
"instance_metadata": BACKENDS['instance_metadata']['global'], "instance_metadata": BACKENDS["instance_metadata"]["global"],
"moto_api": BACKENDS['moto_api']['global'], "moto_api": BACKENDS["moto_api"]["global"],
} }
self.backends_for_urls.update(self.backends) self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends) self.backends_for_urls.update(default_backends)
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
FAKE_KEYS = {
"AWS_ACCESS_KEY_ID": "foobar_key",
"AWS_SECRET_ACCESS_KEY": "foobar_secret",
}
self.default_session_mock = mock.patch("boto3.DEFAULT_SESSION", None)
self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
if self.__class__.nested_count == 0: if self.__class__.nested_count == 0:
self.reset() self.reset()
@ -52,11 +60,15 @@ class BaseMockAWS(object):
def __enter__(self): def __enter__(self):
self.start() self.start()
return self
def __exit__(self, *args): def __exit__(self, *args):
self.stop() self.stop()
def start(self, reset=True): def start(self, reset=True):
self.default_session_mock.start()
self.env_variables_mocks.start()
self.__class__.nested_count += 1 self.__class__.nested_count += 1
if reset: if reset:
for backend in self.backends.values(): for backend in self.backends.values():
@ -65,10 +77,12 @@ class BaseMockAWS(object):
self.enable_patching() self.enable_patching()
def stop(self): def stop(self):
self.default_session_mock.stop()
self.env_variables_mocks.stop()
self.__class__.nested_count -= 1 self.__class__.nested_count -= 1
if self.__class__.nested_count < 0: if self.__class__.nested_count < 0:
raise RuntimeError('Called stop() before start().') raise RuntimeError("Called stop() before start().")
if self.__class__.nested_count == 0: if self.__class__.nested_count == 0:
self.disable_patching() self.disable_patching()
@ -81,6 +95,7 @@ class BaseMockAWS(object):
finally: finally:
self.stop() self.stop()
return result return result
functools.update_wrapper(wrapper, func) functools.update_wrapper(wrapper, func)
wrapper.__wrapped__ = func wrapper.__wrapped__ = func
return wrapper return wrapper
@ -118,7 +133,6 @@ class BaseMockAWS(object):
class HttprettyMockAWS(BaseMockAWS): class HttprettyMockAWS(BaseMockAWS):
def reset(self): def reset(self):
HTTPretty.reset() HTTPretty.reset()
@ -140,18 +154,26 @@ class HttprettyMockAWS(BaseMockAWS):
HTTPretty.reset() HTTPretty.reset()
RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD, RESPONSES_METHODS = [
responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT] responses.GET,
responses.DELETE,
responses.HEAD,
responses.OPTIONS,
responses.PATCH,
responses.POST,
responses.PUT,
]
class CallbackResponse(responses.CallbackResponse): class CallbackResponse(responses.CallbackResponse):
''' """
Need to subclass so we can change a couple things Need to subclass so we can change a couple things
''' """
def get_response(self, request): def get_response(self, request):
''' """
Need to override this so we can pass decode_content=False Need to override this so we can pass decode_content=False
''' """
headers = self.get_headers() headers = self.get_headers()
result = self.callback(request) result = self.callback(request)
@ -173,17 +195,17 @@ class CallbackResponse(responses.CallbackResponse):
) )
def _url_matches(self, url, other, match_querystring=False): def _url_matches(self, url, other, match_querystring=False):
''' """
Need to override this so we can fix querystrings breaking regex matching Need to override this so we can fix querystrings breaking regex matching
''' """
if not match_querystring: if not match_querystring:
other = other.split('?', 1)[0] other = other.split("?", 1)[0]
if responses._is_string(url): if responses._is_string(url):
if responses._has_unicode(url): if responses._has_unicode(url):
url = responses._clean_unicode(url) url = responses._clean_unicode(url)
if not isinstance(other, six.text_type): if not isinstance(other, six.text_type):
other = other.encode('ascii').decode('utf8') other = other.encode("ascii").decode("utf8")
return self._url_matches_strict(url, other) return self._url_matches_strict(url, other)
elif isinstance(url, responses.Pattern) and url.match(other): elif isinstance(url, responses.Pattern) and url.match(other):
return True return True
@ -191,66 +213,23 @@ class CallbackResponse(responses.CallbackResponse):
return False return False
botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send') botocore_mock = responses.RequestsMock(
assert_all_requests_are_fired=False,
target="botocore.vendored.requests.adapters.HTTPAdapter.send",
)
responses_mock = responses._default_mock responses_mock = responses._default_mock
# Add passthrough to allow any other requests to work
# Since this uses .startswith, it applies to http and https requests.
responses_mock.add_passthru("http")
class ResponsesMockAWS(BaseMockAWS): BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
def reset(self):
botocore_mock.reset()
responses_mock.reset()
def enable_patching(self):
if not hasattr(botocore_mock, '_patcher') or not hasattr(botocore_mock._patcher, 'target'):
# Check for unactivated patcher
botocore_mock.start()
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'):
responses_mock.start()
for method in RESPONSES_METHODS:
for backend in self.backends_for_urls.values():
for key, value in backend.urls.items():
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile(key),
callback=convert_flask_to_responses_response(value),
stream=True,
match_querystring=False,
)
)
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile(key),
callback=convert_flask_to_responses_response(value),
stream=True,
match_querystring=False,
)
)
def disable_patching(self):
try:
botocore_mock.stop()
except RuntimeError:
pass
try:
responses_mock.stop()
except RuntimeError:
pass
BOTOCORE_HTTP_METHODS = [
'GET', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT'
]
class MockRawResponse(BytesIO): class MockRawResponse(BytesIO):
def __init__(self, input): def __init__(self, input):
if isinstance(input, six.text_type): if isinstance(input, six.text_type):
input = input.encode('utf-8') input = input.encode("utf-8")
super(MockRawResponse, self).__init__(input) super(MockRawResponse, self).__init__(input)
def stream(self, **kwargs): def stream(self, **kwargs):
@ -281,7 +260,7 @@ class BotocoreStubber(object):
found_index = None found_index = None
matchers = self.methods.get(request.method) matchers = self.methods.get(request.method)
base_url = request.url.split('?', 1)[0] base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers): for i, (pattern, callback) in enumerate(matchers):
if pattern.match(base_url): if pattern.match(base_url):
if found_index is None: if found_index is None:
@ -294,8 +273,10 @@ class BotocoreStubber(object):
if response_callback is not None: if response_callback is not None:
for header, value in request.headers.items(): for header, value in request.headers.items():
if isinstance(value, six.binary_type): if isinstance(value, six.binary_type):
request.headers[header] = value.decode('utf-8') request.headers[header] = value.decode("utf-8")
status, headers, body = response_callback(request, request.url, request.headers) status, headers, body = response_callback(
request, request.url, request.headers
)
body = MockRawResponse(body) body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body) response = AWSResponse(request.url, status, headers, body)
@ -303,7 +284,15 @@ class BotocoreStubber(object):
botocore_stubber = BotocoreStubber() botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(('before-send', botocore_stubber)) BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
def not_implemented_callback(request):
status = 400
headers = {}
response = "The method is not implemented"
return status, headers, response
class BotocoreEventMockAWS(BaseMockAWS): class BotocoreEventMockAWS(BaseMockAWS):
@ -319,7 +308,9 @@ class BotocoreEventMockAWS(BaseMockAWS):
pattern = re.compile(key) pattern = re.compile(key)
botocore_stubber.register_response(method, pattern, value) botocore_stubber.register_response(method, pattern, value)
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'): if not hasattr(responses_mock, "_patcher") or not hasattr(
responses_mock._patcher, "target"
):
responses_mock.start() responses_mock.start()
for method in RESPONSES_METHODS: for method in RESPONSES_METHODS:
@ -335,6 +326,24 @@ class BotocoreEventMockAWS(BaseMockAWS):
match_querystring=False, match_querystring=False,
) )
) )
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile("https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
)
)
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile("https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
)
)
def disable_patching(self): def disable_patching(self):
botocore_stubber.enabled = False botocore_stubber.enabled = False
@ -350,9 +359,9 @@ MockAWS = BotocoreEventMockAWS
class ServerModeMockAWS(BaseMockAWS): class ServerModeMockAWS(BaseMockAWS):
def reset(self): def reset(self):
import requests import requests
requests.post("http://localhost:5000/moto-api/reset") requests.post("http://localhost:5000/moto-api/reset")
def enable_patching(self): def enable_patching(self):
@ -364,13 +373,13 @@ class ServerModeMockAWS(BaseMockAWS):
import mock import mock
def fake_boto3_client(*args, **kwargs): def fake_boto3_client(*args, **kwargs):
if 'endpoint_url' not in kwargs: if "endpoint_url" not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000" kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_client(*args, **kwargs) return real_boto3_client(*args, **kwargs)
def fake_boto3_resource(*args, **kwargs): def fake_boto3_resource(*args, **kwargs):
if 'endpoint_url' not in kwargs: if "endpoint_url" not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000" kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_resource(*args, **kwargs) return real_boto3_resource(*args, **kwargs)
def fake_httplib_send_output(self, message_body=None, *args, **kwargs): def fake_httplib_send_output(self, message_body=None, *args, **kwargs):
@ -378,7 +387,7 @@ class ServerModeMockAWS(BaseMockAWS):
bytes_buffer = [] bytes_buffer = []
for chunk in mixed_buffer: for chunk in mixed_buffer:
if isinstance(chunk, six.text_type): if isinstance(chunk, six.text_type):
bytes_buffer.append(chunk.encode('utf-8')) bytes_buffer.append(chunk.encode("utf-8"))
else: else:
bytes_buffer.append(chunk) bytes_buffer.append(chunk)
msg = b"\r\n".join(bytes_buffer) msg = b"\r\n".join(bytes_buffer)
@ -399,10 +408,12 @@ class ServerModeMockAWS(BaseMockAWS):
if message_body is not None: if message_body is not None:
self.send(message_body) self.send(message_body)
self._client_patcher = mock.patch('boto3.client', fake_boto3_client) self._client_patcher = mock.patch("boto3.client", fake_boto3_client)
self._resource_patcher = mock.patch('boto3.resource', fake_boto3_resource) self._resource_patcher = mock.patch("boto3.resource", fake_boto3_resource)
if six.PY2: if six.PY2:
self._httplib_patcher = mock.patch('httplib.HTTPConnection._send_output', fake_httplib_send_output) self._httplib_patcher = mock.patch(
"httplib.HTTPConnection._send_output", fake_httplib_send_output
)
self._client_patcher.start() self._client_patcher.start()
self._resource_patcher.start() self._resource_patcher.start()
@ -418,7 +429,6 @@ class ServerModeMockAWS(BaseMockAWS):
class Model(type): class Model(type):
def __new__(self, clsname, bases, namespace): def __new__(self, clsname, bases, namespace):
cls = super(Model, self).__new__(self, clsname, bases, namespace) cls = super(Model, self).__new__(self, clsname, bases, namespace)
cls.__models__ = {} cls.__models__ = {}
@ -433,9 +443,11 @@ class Model(type):
@staticmethod @staticmethod
def prop(model_name): def prop(model_name):
""" decorator to mark a class method as returning model values """ """ decorator to mark a class method as returning model values """
def dec(f): def dec(f):
f.__returns_model__ = model_name f.__returns_model__ = model_name
return f return f
return dec return dec
@ -445,7 +457,7 @@ model_data = defaultdict(dict)
class InstanceTrackerMeta(type): class InstanceTrackerMeta(type):
def __new__(meta, name, bases, dct): def __new__(meta, name, bases, dct):
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct) cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
if name == 'BaseModel': if name == "BaseModel":
return cls return cls
service = cls.__module__.split(".")[1] service = cls.__module__.split(".")[1]
@ -464,11 +476,14 @@ class BaseModel(object):
class BaseBackend(object): class BaseBackend(object):
def _reset_model_refs(self):
def reset(self): # Remove all references to the models stored
for service, models in model_data.items(): for service, models in model_data.items():
for model_name, model in models.items(): for model_name, model in models.items():
model.instances = [] model.instances = []
def reset(self):
self._reset_model_refs()
self.__dict__ = {} self.__dict__ = {}
self.__init__() self.__init__()
@ -476,8 +491,9 @@ class BaseBackend(object):
def _url_module(self): def _url_module(self):
backend_module = self.__class__.__module__ backend_module = self.__class__.__module__
backend_urls_module_name = backend_module.replace("models", "urls") backend_urls_module_name = backend_module.replace("models", "urls")
backend_urls_module = __import__(backend_urls_module_name, fromlist=[ backend_urls_module = __import__(
'url_bases', 'url_paths']) backend_urls_module_name, fromlist=["url_bases", "url_paths"]
)
return backend_urls_module return backend_urls_module
@property @property
@ -533,9 +549,9 @@ class BaseBackend(object):
def decorator(self, func=None): def decorator(self, func=None):
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
mocked_backend = ServerModeMockAWS({'global': self}) mocked_backend = ServerModeMockAWS({"global": self})
else: else:
mocked_backend = MockAWS({'global': self}) mocked_backend = MockAWS({"global": self})
if func: if func:
return mocked_backend(func) return mocked_backend(func)
@ -544,9 +560,100 @@ class BaseBackend(object):
def deprecated_decorator(self, func=None): def deprecated_decorator(self, func=None):
if func: if func:
return HttprettyMockAWS({'global': self})(func) return HttprettyMockAWS({"global": self})(func)
else: else:
return HttprettyMockAWS({'global': self}) return HttprettyMockAWS({"global": self})
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
# raise NotImplementedError()
class ConfigQueryModel(object):
def __init__(self, backends):
"""Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends
def list_config_service_resources(
self,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
):
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region.
This supports both aggregated and non-aggregated listing. The following notes the difference:
- Non-Aggregated Listing -
This only lists resources within a region. The way that this is implemented in moto is based on the region
for the resource backend.
You must set the `backend_region` to the region that the API request arrived from. resource_region can be set to `None`.
- Aggregated Listing -
This lists resources from all potential regional backends. For non-global resource types, this should collect a full
list of resources from all the backends, and then be able to filter from the resource region. This is because an
aggregator can aggregate resources from multiple regions. In moto, aggregated regions will *assume full aggregation
from all resources in all regions for a given resource type*.
The `backend_region` should be set to `None` for these queries, and the `resource_region` should optionally be set to
the `Filters` region parameter to filter out resources that reside in a specific region.
For aggregated listings, pagination logic should be set such that the next page can properly span all the region backends.
As such, the proper way to implement is to first obtain a full list of results from all the region backends, and then filter
from there. It may be valuable to make this a concatenation of the region and resource name.
:param resource_region:
:param resource_ids:
:param resource_name:
:param limit:
:param next_token:
:param backend_region: The region for the backend to pull results from. Set to `None` if this is an aggregated query.
:return: This should return a list of Dicts that have the following fields:
[
{
'type': 'AWS::The AWS Config data type',
'name': 'The name of the resource',
'id': 'The ID of the resource',
'region': 'The region of the resource -- if global, then you may want to have the calling logic pass in the
aggregator region in for the resource region -- or just us-east-1 :P'
}
, ...
]
"""
raise NotImplementedError()
def get_config_resource(
self, resource_id, resource_name=None, backend_region=None, resource_region=None
):
"""For AWS Config. This will query the backend for the specific resource type configuration.
This supports both aggregated, and non-aggregated fetching -- for batched fetching -- the Config batching requests
will call this function N times to fetch the N objects needing to be fetched.
- Non-Aggregated Fetching -
This only fetches a resource config within a region. The way that this is implemented in moto is based on the region
for the resource backend.
You must set the `backend_region` to the region that the API request arrived from. `resource_region` should be set to `None`.
- Aggregated Fetching -
This fetches resources from all potential regional backends. For non-global resource types, this should collect a full
list of resources from all the backends, and then be able to filter from the resource region. This is because an
aggregator can aggregate resources from multiple regions. In moto, aggregated regions will *assume full aggregation
from all resources in all regions for a given resource type*.
...
:param resource_id:
:param resource_name:
:param backend_region:
:param resource_region:
:return:
"""
raise NotImplementedError()
class base_decorator(object): class base_decorator(object):
@ -572,9 +679,9 @@ class deprecated_base_decorator(base_decorator):
class MotoAPIBackend(BaseBackend): class MotoAPIBackend(BaseBackend):
def reset(self): def reset(self):
from moto.backends import BACKENDS from moto.backends import BACKENDS
for name, backends in BACKENDS.items(): for name, backends in BACKENDS.items():
if name == "moto_api": if name == "moto_api":
continue continue

View File

@ -1,13 +1,17 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import functools
from collections import defaultdict from collections import defaultdict
import datetime import datetime
import json import json
import logging import logging
import re import re
import io import io
import requests
import pytz import pytz
from moto.core.access_control import IAMRequest, S3IAMRequest
from moto.core.exceptions import DryRunClientError from moto.core.exceptions import DryRunClientError
from jinja2 import Environment, DictLoader, TemplateNotFound from jinja2 import Environment, DictLoader, TemplateNotFound
@ -22,7 +26,7 @@ from werkzeug.exceptions import HTTPException
import boto3 import boto3
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core.utils import camelcase_to_underscores, method_names_from_class from moto.core.utils import camelcase_to_underscores, method_names_from_class
from moto import settings
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -36,7 +40,7 @@ def _decode_dict(d):
newkey = [] newkey = []
for k in key: for k in key:
if isinstance(k, six.binary_type): if isinstance(k, six.binary_type):
newkey.append(k.decode('utf-8')) newkey.append(k.decode("utf-8"))
else: else:
newkey.append(k) newkey.append(k)
else: else:
@ -48,7 +52,7 @@ def _decode_dict(d):
newvalue = [] newvalue = []
for v in value: for v in value:
if isinstance(v, six.binary_type): if isinstance(v, six.binary_type):
newvalue.append(v.decode('utf-8')) newvalue.append(v.decode("utf-8"))
else: else:
newvalue.append(v) newvalue.append(v)
else: else:
@ -79,12 +83,15 @@ class DynamicDictLoader(DictLoader):
class _TemplateEnvironmentMixin(object): class _TemplateEnvironmentMixin(object):
LEFT_PATTERN = re.compile(r"[\s\n]+<")
RIGHT_PATTERN = re.compile(r">[\s\n]+")
def __init__(self): def __init__(self):
super(_TemplateEnvironmentMixin, self).__init__() super(_TemplateEnvironmentMixin, self).__init__()
self.loader = DynamicDictLoader({}) self.loader = DynamicDictLoader({})
self.environment = Environment( self.environment = Environment(
loader=self.loader, autoescape=self.should_autoescape) loader=self.loader, autoescape=self.should_autoescape
)
@property @property
def should_autoescape(self): def should_autoescape(self):
@ -97,19 +104,92 @@ class _TemplateEnvironmentMixin(object):
def response_template(self, source): def response_template(self, source):
template_id = id(source) template_id = id(source)
if not self.contains_template(template_id): if not self.contains_template(template_id):
self.loader.update({template_id: source}) collapsed = re.sub(
self.environment = Environment(loader=self.loader, autoescape=self.should_autoescape, trim_blocks=True, self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source)
lstrip_blocks=True) )
self.loader.update({template_id: collapsed})
self.environment = Environment(
loader=self.loader,
autoescape=self.should_autoescape,
trim_blocks=True,
lstrip_blocks=True,
)
return self.environment.get_template(template_id) return self.environment.get_template(template_id)
class BaseResponse(_TemplateEnvironmentMixin): class ActionAuthenticatorMixin(object):
default_region = 'us-east-1' request_count = 0
def _authenticate_and_authorize_action(self, iam_request_cls):
if (
ActionAuthenticatorMixin.request_count
>= settings.INITIAL_NO_AUTH_ACTION_COUNT
):
iam_request = iam_request_cls(
method=self.method, path=self.path, data=self.data, headers=self.headers
)
iam_request.check_signature()
iam_request.check_action_permitted()
else:
ActionAuthenticatorMixin.request_count += 1
def _authenticate_and_authorize_normal_action(self):
self._authenticate_and_authorize_action(IAMRequest)
def _authenticate_and_authorize_s3_action(self):
self._authenticate_and_authorize_action(S3IAMRequest)
@staticmethod
def set_initial_no_auth_action_count(initial_no_auth_action_count):
def decorator(function):
def wrapper(*args, **kwargs):
if settings.TEST_SERVER_MODE:
response = requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(initial_no_auth_action_count).encode(),
)
original_initial_no_auth_action_count = response.json()[
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT"
]
else:
original_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
original_request_count = ActionAuthenticatorMixin.request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count
ActionAuthenticatorMixin.request_count = 0
try:
result = function(*args, **kwargs)
finally:
if settings.TEST_SERVER_MODE:
requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(original_initial_no_auth_action_count).encode(),
)
else:
ActionAuthenticatorMixin.request_count = original_request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = (
original_initial_no_auth_action_count
)
return result
functools.update_wrapper(wrapper, function)
wrapper.__wrapped__ = function
return wrapper
return decorator
class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = "us-east-1"
# to extract region, use [^.] # to extract region, use [^.]
region_regex = re.compile(r'\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com') region_regex = re.compile(r"\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com")
param_list_regex = re.compile(r'(.*)\.(\d+)\.') param_list_regex = re.compile(r"(.*)\.(\d+)\.")
access_key_regex = re.compile(r'AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]') access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
)
aws_service_spec = None aws_service_spec = None
@classmethod @classmethod
@ -118,7 +198,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
def setup_class(self, request, full_url, headers): def setup_class(self, request, full_url, headers):
querystring = {} querystring = {}
if hasattr(request, 'body'): if hasattr(request, "body"):
# Boto # Boto
self.body = request.body self.body = request.body
else: else:
@ -131,51 +211,66 @@ class BaseResponse(_TemplateEnvironmentMixin):
querystring = {} querystring = {}
for key, value in request.form.items(): for key, value in request.form.items():
querystring[key] = [value, ] querystring[key] = [value]
raw_body = self.body raw_body = self.body
if isinstance(self.body, six.binary_type): if isinstance(self.body, six.binary_type):
self.body = self.body.decode('utf-8') self.body = self.body.decode("utf-8")
if not querystring: if not querystring:
querystring.update( querystring.update(
parse_qs(urlparse(full_url).query, keep_blank_values=True)) parse_qs(urlparse(full_url).query, keep_blank_values=True)
)
if not querystring: if not querystring:
if 'json' in request.headers.get('content-type', []) and self.aws_service_spec: if (
"json" in request.headers.get("content-type", [])
and self.aws_service_spec
):
decoded = json.loads(self.body) decoded = json.loads(self.body)
target = request.headers.get( target = request.headers.get("x-amz-target") or request.headers.get(
'x-amz-target') or request.headers.get('X-Amz-Target') "X-Amz-Target"
service, method = target.split('.') )
service, method = target.split(".")
input_spec = self.aws_service_spec.input_spec(method) input_spec = self.aws_service_spec.input_spec(method)
flat = flatten_json_request_body('', decoded, input_spec) flat = flatten_json_request_body("", decoded, input_spec)
for key, value in flat.items(): for key, value in flat.items():
querystring[key] = [value] querystring[key] = [value]
elif self.body: elif self.body:
querystring.update(parse_qs(raw_body, keep_blank_values=True)) try:
querystring.update(parse_qs(raw_body, keep_blank_values=True))
except UnicodeEncodeError:
pass # ignore encoding errors, as the body may not contain a legitimate querystring
if not querystring: if not querystring:
querystring.update(headers) querystring.update(headers)
querystring = _decode_dict(querystring) try:
querystring = _decode_dict(querystring)
except UnicodeDecodeError:
pass # ignore decoding errors, as the body may not contain a legitimate querystring
self.uri = full_url self.uri = full_url
self.path = urlparse(full_url).path self.path = urlparse(full_url).path
self.querystring = querystring self.querystring = querystring
self.data = querystring
self.method = request.method self.method = request.method
self.region = self.get_region_from_url(request, full_url) self.region = self.get_region_from_url(request, full_url)
self.uri_match = None self.uri_match = None
self.headers = request.headers self.headers = request.headers
if 'host' not in self.headers: if "host" not in self.headers:
self.headers['host'] = urlparse(full_url).netloc self.headers["host"] = urlparse(full_url).netloc
self.response_headers = {"server": "amazon.com"} self.response_headers = {"server": "amazon.com"}
def get_region_from_url(self, request, full_url): def get_region_from_url(self, request, full_url):
match = self.region_regex.search(full_url) match = self.region_regex.search(full_url)
if match: if match:
region = match.group(1) region = match.group(1)
elif 'Authorization' in request.headers and 'AWS4' in request.headers['Authorization']: elif (
region = request.headers['Authorization'].split(",")[ "Authorization" in request.headers
0].split("/")[2] and "AWS4" in request.headers["Authorization"]
):
region = request.headers["Authorization"].split(",")[0].split("/")[2]
else: else:
region = self.default_region region = self.default_region
return region return region
@ -184,16 +279,16 @@ class BaseResponse(_TemplateEnvironmentMixin):
""" """
Returns the access key id used in this request as the current user id Returns the access key id used in this request as the current user id
""" """
if 'Authorization' in self.headers: if "Authorization" in self.headers:
match = self.access_key_regex.search(self.headers['Authorization']) match = self.access_key_regex.search(self.headers["Authorization"])
if match: if match:
return match.group(1) return match.group(1)
if self.querystring.get('AWSAccessKeyId'): if self.querystring.get("AWSAccessKeyId"):
return self.querystring.get('AWSAccessKeyId') return self.querystring.get("AWSAccessKeyId")
else: else:
# Should we raise an unauthorized exception instead? # Should we raise an unauthorized exception instead?
return '111122223333' return "111122223333"
def _dispatch(self, request, full_url, headers): def _dispatch(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -208,17 +303,22 @@ class BaseResponse(_TemplateEnvironmentMixin):
-> '^/cars/.*/drivers/.*/drive$' -> '^/cars/.*/drivers/.*/drive$'
""" """
def _convert(elem, is_last):
if not re.match('^{.*}$', elem):
return elem
name = elem.replace('{', '').replace('}', '')
if is_last:
return '(?P<%s>[^/]*)' % name
return '(?P<%s>.*)' % name
elems = uri.split('/') def _convert(elem, is_last):
if not re.match("^{.*}$", elem):
return elem
name = elem.replace("{", "").replace("}", "").replace("+", "")
if is_last:
return "(?P<%s>[^/]*)" % name
return "(?P<%s>.*)" % name
elems = uri.split("/")
num_elems = len(elems) num_elems = len(elems)
regexp = '^{}$'.format('/'.join([_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)])) regexp = "^{}$".format(
"/".join(
[_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)]
)
)
return regexp return regexp
def _get_action_from_method_and_request_uri(self, method, request_uri): def _get_action_from_method_and_request_uri(self, method, request_uri):
@ -229,19 +329,19 @@ class BaseResponse(_TemplateEnvironmentMixin):
# service response class should have 'SERVICE_NAME' class member, # service response class should have 'SERVICE_NAME' class member,
# if you want to get action from method and url # if you want to get action from method and url
if not hasattr(self, 'SERVICE_NAME'): if not hasattr(self, "SERVICE_NAME"):
return None return None
service = self.SERVICE_NAME service = self.SERVICE_NAME
conn = boto3.client(service, region_name=self.region) conn = boto3.client(service, region_name=self.region)
# make cache if it does not exist yet # make cache if it does not exist yet
if not hasattr(self, 'method_urls'): if not hasattr(self, "method_urls"):
self.method_urls = defaultdict(lambda: defaultdict(str)) self.method_urls = defaultdict(lambda: defaultdict(str))
op_names = conn._service_model.operation_names op_names = conn._service_model.operation_names
for op_name in op_names: for op_name in op_names:
op_model = conn._service_model.operation_model(op_name) op_model = conn._service_model.operation_model(op_name)
_method = op_model.http['method'] _method = op_model.http["method"]
uri_regexp = self.uri_to_regexp(op_model.http['requestUri']) uri_regexp = self.uri_to_regexp(op_model.http["requestUri"])
self.method_urls[_method][uri_regexp] = op_model.name self.method_urls[_method][uri_regexp] = op_model.name
regexp_and_names = self.method_urls[method] regexp_and_names = self.method_urls[method]
for regexp, name in regexp_and_names.items(): for regexp, name in regexp_and_names.items():
@ -252,11 +352,10 @@ class BaseResponse(_TemplateEnvironmentMixin):
return None return None
def _get_action(self): def _get_action(self):
action = self.querystring.get('Action', [""])[0] action = self.querystring.get("Action", [""])[0]
if not action: # Some services use a header for the action if not action: # Some services use a header for the action
# Headers are case-insensitive. Probably a better way to do this. # Headers are case-insensitive. Probably a better way to do this.
match = self.headers.get( match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target")
'x-amz-target') or self.headers.get('X-Amz-Target')
if match: if match:
action = match.split(".")[-1] action = match.split(".")[-1]
# get action from method and uri # get action from method and uri
@ -266,6 +365,13 @@ class BaseResponse(_TemplateEnvironmentMixin):
def call_action(self): def call_action(self):
headers = self.response_headers headers = self.response_headers
try:
self._authenticate_and_authorize_normal_action()
except HTTPException as http_error:
response = http_error.description, dict(status=http_error.code)
return self._send_response(headers, response)
action = camelcase_to_underscores(self._get_action()) action = camelcase_to_underscores(self._get_action())
method_names = method_names_from_class(self.__class__) method_names = method_names_from_class(self.__class__)
if action in method_names: if action in method_names:
@ -278,22 +384,27 @@ class BaseResponse(_TemplateEnvironmentMixin):
if isinstance(response, six.string_types): if isinstance(response, six.string_types):
return 200, headers, response return 200, headers, response
else: else:
if len(response) == 2: return self._send_response(headers, response)
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
return status, headers, body
if not action: if not action:
return 404, headers, '' return 404, headers, ""
raise NotImplementedError( raise NotImplementedError(
"The {0} action has not been implemented".format(action)) "The {0} action has not been implemented".format(action)
)
@staticmethod
def _send_response(headers, response):
if len(response) == 2:
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get("status", 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers["status"] = str(headers["status"])
return status, headers, body
def _get_param(self, param_name, if_none=None): def _get_param(self, param_name, if_none=None):
val = self.querystring.get(param_name) val = self.querystring.get(param_name)
@ -326,9 +437,9 @@ class BaseResponse(_TemplateEnvironmentMixin):
def _get_bool_param(self, param_name, if_none=None): def _get_bool_param(self, param_name, if_none=None):
val = self._get_param(param_name) val = self._get_param(param_name)
if val is not None: if val is not None:
if val.lower() == 'true': if val.lower() == "true":
return True return True
elif val.lower() == 'false': elif val.lower() == "false":
return False return False
return if_none return if_none
@ -346,11 +457,16 @@ class BaseResponse(_TemplateEnvironmentMixin):
if is_tracked(name) or not name.startswith(param_prefix): if is_tracked(name) or not name.startswith(param_prefix):
continue continue
if len(name) > len(param_prefix) and \ if len(name) > len(param_prefix) and not name[
not name[len(param_prefix):].startswith('.'): len(param_prefix) :
].startswith("."):
continue continue
match = self.param_list_regex.search(name[len(param_prefix):]) if len(name) > len(param_prefix) else None match = (
self.param_list_regex.search(name[len(param_prefix) :])
if len(name) > len(param_prefix)
else None
)
if match: if match:
prefix = param_prefix + match.group(1) prefix = param_prefix + match.group(1)
value = self._get_multi_param(prefix) value = self._get_multi_param(prefix)
@ -365,7 +481,10 @@ class BaseResponse(_TemplateEnvironmentMixin):
if len(value_dict) > 1: if len(value_dict) > 1:
# strip off period prefix # strip off period prefix
value_dict = {name[len(param_prefix) + 1:]: value for name, value in value_dict.items()} value_dict = {
name[len(param_prefix) + 1 :]: value
for name, value in value_dict.items()
}
else: else:
value_dict = list(value_dict.values())[0] value_dict = list(value_dict.values())[0]
@ -384,7 +503,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
index = 1 index = 1
while True: while True:
value_dict = self._get_multi_param_helper(prefix + str(index)) value_dict = self._get_multi_param_helper(prefix + str(index))
if not value_dict: if not value_dict and value_dict != "":
break break
values.append(value_dict) values.append(value_dict)
@ -409,8 +528,9 @@ class BaseResponse(_TemplateEnvironmentMixin):
params = {} params = {}
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith(param_prefix): if key.startswith(param_prefix):
params[camelcase_to_underscores( params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[
key.replace(param_prefix, ""))] = value[0] 0
]
return params return params
def _get_list_prefix(self, param_prefix): def _get_list_prefix(self, param_prefix):
@ -443,19 +563,20 @@ class BaseResponse(_TemplateEnvironmentMixin):
new_items = {} new_items = {}
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith(index_prefix): if key.startswith(index_prefix):
new_items[camelcase_to_underscores( new_items[
key.replace(index_prefix, ""))] = value[0] camelcase_to_underscores(key.replace(index_prefix, ""))
] = value[0]
if not new_items: if not new_items:
break break
results.append(new_items) results.append(new_items)
param_index += 1 param_index += 1
return results return results
def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'): def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"):
results = {} results = {}
param_index = 1 param_index = 1
while 1: while 1:
index_prefix = '{0}.{1}.'.format(param_prefix, param_index) index_prefix = "{0}.{1}.".format(param_prefix, param_index)
k, v = None, None k, v = None, None
for key, value in self.querystring.items(): for key, value in self.querystring.items():
@ -482,8 +603,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
param_index = 1 param_index = 1
while True: while True:
key_name = 'tag.{0}._key'.format(param_index) key_name = "tag.{0}._key".format(param_index)
value_name = 'tag.{0}._value'.format(param_index) value_name = "tag.{0}._value".format(param_index)
try: try:
results[resource_type][tag[key_name]] = tag[value_name] results[resource_type][tag[key_name]] = tag[value_name]
@ -493,7 +614,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
return results return results
def _get_object_map(self, prefix, name='Name', value='Value'): def _get_object_map(self, prefix, name="Name", value="Value"):
""" """
Given a query dict like Given a query dict like
{ {
@ -521,15 +642,14 @@ class BaseResponse(_TemplateEnvironmentMixin):
index = 1 index = 1
while True: while True:
# Loop through looking for keys representing object name # Loop through looking for keys representing object name
name_key = '{0}.{1}.{2}'.format(prefix, index, name) name_key = "{0}.{1}.{2}".format(prefix, index, name)
obj_name = self.querystring.get(name_key) obj_name = self.querystring.get(name_key)
if not obj_name: if not obj_name:
# Found all keys # Found all keys
break break
obj = {} obj = {}
value_key_prefix = '{0}.{1}.{2}.'.format( value_key_prefix = "{0}.{1}.{2}.".format(prefix, index, value)
prefix, index, value)
for k, v in self.querystring.items(): for k, v in self.querystring.items():
if k.startswith(value_key_prefix): if k.startswith(value_key_prefix):
_, value_key = k.split(value_key_prefix, 1) _, value_key = k.split(value_key_prefix, 1)
@ -543,25 +663,48 @@ class BaseResponse(_TemplateEnvironmentMixin):
@property @property
def request_json(self): def request_json(self):
return 'JSON' in self.querystring.get('ContentType', []) return "JSON" in self.querystring.get("ContentType", [])
def is_not_dryrun(self, action): def is_not_dryrun(self, action):
if 'true' in self.querystring.get('DryRun', ['false']): if "true" in self.querystring.get("DryRun", ["false"]):
message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action message = (
raise DryRunClientError( "An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set"
error_type="DryRunOperation", message=message) % action
)
raise DryRunClientError(error_type="DryRunOperation", message=message)
return True return True
class MotoAPIResponse(BaseResponse): class MotoAPIResponse(BaseResponse):
def reset_response(self, request, full_url, headers): def reset_response(self, request, full_url, headers):
if request.method == "POST": if request.method == "POST":
from .models import moto_api_backend from .models import moto_api_backend
moto_api_backend.reset() moto_api_backend.reset()
return 200, {}, json.dumps({"status": "ok"}) return 200, {}, json.dumps({"status": "ok"})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"}) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"})
def reset_auth_response(self, request, full_url, headers):
if request.method == "POST":
previous_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode())
ActionAuthenticatorMixin.request_count = 0
return (
200,
{},
json.dumps(
{
"status": "ok",
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(
previous_initial_no_auth_action_count
),
}
),
)
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
def model_data(self, request, full_url, headers): def model_data(self, request, full_url, headers):
from moto.core.models import model_data from moto.core.models import model_data
@ -587,7 +730,8 @@ class MotoAPIResponse(BaseResponse):
def dashboard(self, request, full_url, headers): def dashboard(self, request, full_url, headers):
from flask import render_template from flask import render_template
return render_template('dashboard.html')
return render_template("dashboard.html")
class _RecursiveDictRef(object): class _RecursiveDictRef(object):
@ -598,7 +742,7 @@ class _RecursiveDictRef(object):
self.dic = {} self.dic = {}
def __repr__(self): def __repr__(self):
return '{!r}'.format(self.dic) return "{!r}".format(self.dic)
def __getattr__(self, key): def __getattr__(self, key):
return self.dic.__getattr__(key) return self.dic.__getattr__(key)
@ -622,21 +766,21 @@ class AWSServiceSpec(object):
""" """
def __init__(self, path): def __init__(self, path):
self.path = resource_filename('botocore', path) self.path = resource_filename("botocore", path)
with io.open(self.path, 'r', encoding='utf-8') as f: with io.open(self.path, "r", encoding="utf-8") as f:
spec = json.load(f) spec = json.load(f)
self.metadata = spec['metadata'] self.metadata = spec["metadata"]
self.operations = spec['operations'] self.operations = spec["operations"]
self.shapes = spec['shapes'] self.shapes = spec["shapes"]
def input_spec(self, operation): def input_spec(self, operation):
try: try:
op = self.operations[operation] op = self.operations[operation]
except KeyError: except KeyError:
raise ValueError('Invalid operation: {}'.format(operation)) raise ValueError("Invalid operation: {}".format(operation))
if 'input' not in op: if "input" not in op:
return {} return {}
shape = self.shapes[op['input']['shape']] shape = self.shapes[op["input"]["shape"]]
return self._expand(shape) return self._expand(shape)
def output_spec(self, operation): def output_spec(self, operation):
@ -650,129 +794,133 @@ class AWSServiceSpec(object):
try: try:
op = self.operations[operation] op = self.operations[operation]
except KeyError: except KeyError:
raise ValueError('Invalid operation: {}'.format(operation)) raise ValueError("Invalid operation: {}".format(operation))
if 'output' not in op: if "output" not in op:
return {} return {}
shape = self.shapes[op['output']['shape']] shape = self.shapes[op["output"]["shape"]]
return self._expand(shape) return self._expand(shape)
def _expand(self, shape): def _expand(self, shape):
def expand(dic, seen=None): def expand(dic, seen=None):
seen = seen or {} seen = seen or {}
if dic['type'] == 'structure': if dic["type"] == "structure":
nodes = {} nodes = {}
for k, v in dic['members'].items(): for k, v in dic["members"].items():
seen_till_here = dict(seen) seen_till_here = dict(seen)
if k in seen_till_here: if k in seen_till_here:
nodes[k] = seen_till_here[k] nodes[k] = seen_till_here[k]
continue continue
seen_till_here[k] = _RecursiveDictRef() seen_till_here[k] = _RecursiveDictRef()
nodes[k] = expand(self.shapes[v['shape']], seen_till_here) nodes[k] = expand(self.shapes[v["shape"]], seen_till_here)
seen_till_here[k].set_reference(k, nodes[k]) seen_till_here[k].set_reference(k, nodes[k])
nodes['type'] = 'structure' nodes["type"] = "structure"
return nodes return nodes
elif dic['type'] == 'list': elif dic["type"] == "list":
seen_till_here = dict(seen) seen_till_here = dict(seen)
shape = dic['member']['shape'] shape = dic["member"]["shape"]
if shape in seen_till_here: if shape in seen_till_here:
return seen_till_here[shape] return seen_till_here[shape]
seen_till_here[shape] = _RecursiveDictRef() seen_till_here[shape] = _RecursiveDictRef()
expanded = expand(self.shapes[shape], seen_till_here) expanded = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, expanded) seen_till_here[shape].set_reference(shape, expanded)
return {'type': 'list', 'member': expanded} return {"type": "list", "member": expanded}
elif dic['type'] == 'map': elif dic["type"] == "map":
seen_till_here = dict(seen) seen_till_here = dict(seen)
node = {'type': 'map'} node = {"type": "map"}
if 'shape' in dic['key']: if "shape" in dic["key"]:
shape = dic['key']['shape'] shape = dic["key"]["shape"]
seen_till_here[shape] = _RecursiveDictRef() seen_till_here[shape] = _RecursiveDictRef()
node['key'] = expand(self.shapes[shape], seen_till_here) node["key"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['key']) seen_till_here[shape].set_reference(shape, node["key"])
else: else:
node['key'] = dic['key']['type'] node["key"] = dic["key"]["type"]
if 'shape' in dic['value']: if "shape" in dic["value"]:
shape = dic['value']['shape'] shape = dic["value"]["shape"]
seen_till_here[shape] = _RecursiveDictRef() seen_till_here[shape] = _RecursiveDictRef()
node['value'] = expand(self.shapes[shape], seen_till_here) node["value"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['value']) seen_till_here[shape].set_reference(shape, node["value"])
else: else:
node['value'] = dic['value']['type'] node["value"] = dic["value"]["type"]
return node return node
else: else:
return {'type': dic['type']} return {"type": dic["type"]}
return expand(shape) return expand(shape)
def to_str(value, spec): def to_str(value, spec):
vtype = spec['type'] vtype = spec["type"]
if vtype == 'boolean': if vtype == "boolean":
return 'true' if value else 'false' return "true" if value else "false"
elif vtype == 'integer': elif vtype == "integer":
return str(value) return str(value)
elif vtype == 'float': elif vtype == "float":
return str(value) return str(value)
elif vtype == 'double': elif vtype == "double":
return str(value) return str(value)
elif vtype == 'timestamp': elif vtype == "timestamp":
return datetime.datetime.utcfromtimestamp( return (
value).replace(tzinfo=pytz.utc).isoformat() datetime.datetime.utcfromtimestamp(value)
elif vtype == 'string': .replace(tzinfo=pytz.utc)
.isoformat()
)
elif vtype == "string":
return str(value) return str(value)
elif value is None: elif value is None:
return 'null' return "null"
else: else:
raise TypeError('Unknown type {}'.format(vtype)) raise TypeError("Unknown type {}".format(vtype))
def from_str(value, spec): def from_str(value, spec):
vtype = spec['type'] vtype = spec["type"]
if vtype == 'boolean': if vtype == "boolean":
return True if value == 'true' else False return True if value == "true" else False
elif vtype == 'integer': elif vtype == "integer":
return int(value) return int(value)
elif vtype == 'float': elif vtype == "float":
return float(value) return float(value)
elif vtype == 'double': elif vtype == "double":
return float(value) return float(value)
elif vtype == 'timestamp': elif vtype == "timestamp":
return value return value
elif vtype == 'string': elif vtype == "string":
return value return value
raise TypeError('Unknown type {}'.format(vtype)) raise TypeError("Unknown type {}".format(vtype))
def flatten_json_request_body(prefix, dict_body, spec): def flatten_json_request_body(prefix, dict_body, spec):
"""Convert a JSON request body into query params.""" """Convert a JSON request body into query params."""
if len(spec) == 1 and 'type' in spec: if len(spec) == 1 and "type" in spec:
return {prefix: to_str(dict_body, spec)} return {prefix: to_str(dict_body, spec)}
flat = {} flat = {}
for key, value in dict_body.items(): for key, value in dict_body.items():
node_type = spec[key]['type'] node_type = spec[key]["type"]
if node_type == 'list': if node_type == "list":
for idx, v in enumerate(value, 1): for idx, v in enumerate(value, 1):
pref = key + '.member.' + str(idx) pref = key + ".member." + str(idx)
flat.update(flatten_json_request_body( flat.update(flatten_json_request_body(pref, v, spec[key]["member"]))
pref, v, spec[key]['member'])) elif node_type == "map":
elif node_type == 'map':
for idx, (k, v) in enumerate(value.items(), 1): for idx, (k, v) in enumerate(value.items(), 1):
pref = key + '.entry.' + str(idx) pref = key + ".entry." + str(idx)
flat.update(flatten_json_request_body( flat.update(
pref + '.key', k, spec[key]['key'])) flatten_json_request_body(pref + ".key", k, spec[key]["key"])
flat.update(flatten_json_request_body( )
pref + '.value', v, spec[key]['value'])) flat.update(
flatten_json_request_body(pref + ".value", v, spec[key]["value"])
)
else: else:
flat.update(flatten_json_request_body(key, value, spec[key])) flat.update(flatten_json_request_body(key, value, spec[key]))
if prefix: if prefix:
prefix = prefix + '.' prefix = prefix + "."
return dict((prefix + k, v) for k, v in flat.items()) return dict((prefix + k, v) for k, v in flat.items())
@ -795,41 +943,40 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
od = OrderedDict() od = OrderedDict()
for k, v in value.items(): for k, v in value.items():
if k.startswith('@'): if k.startswith("@"):
continue continue
if k not in spec: if k not in spec:
# this can happen when with an older version of # this can happen when with an older version of
# botocore for which the node in XML template is not # botocore for which the node in XML template is not
# defined in service spec. # defined in service spec.
log.warning( log.warning("Field %s is not defined by the botocore version in use", k)
'Field %s is not defined by the botocore version in use', k)
continue continue
if spec[k]['type'] == 'list': if spec[k]["type"] == "list":
if v is None: if v is None:
od[k] = [] od[k] = []
elif len(spec[k]['member']) == 1: elif len(spec[k]["member"]) == 1:
if isinstance(v['member'], list): if isinstance(v["member"], list):
od[k] = transform(v['member'], spec[k]['member']) od[k] = transform(v["member"], spec[k]["member"])
else: else:
od[k] = [transform(v['member'], spec[k]['member'])] od[k] = [transform(v["member"], spec[k]["member"])]
elif isinstance(v['member'], list): elif isinstance(v["member"], list):
od[k] = [transform(o, spec[k]['member']) od[k] = [transform(o, spec[k]["member"]) for o in v["member"]]
for o in v['member']] elif isinstance(v["member"], OrderedDict):
elif isinstance(v['member'], OrderedDict): od[k] = [transform(v["member"], spec[k]["member"])]
od[k] = [transform(v['member'], spec[k]['member'])]
else: else:
raise ValueError('Malformatted input') raise ValueError("Malformatted input")
elif spec[k]['type'] == 'map': elif spec[k]["type"] == "map":
if v is None: if v is None:
od[k] = {} od[k] = {}
else: else:
items = ([v['entry']] if not isinstance(v['entry'], list) else items = (
v['entry']) [v["entry"]] if not isinstance(v["entry"], list) else v["entry"]
)
for item in items: for item in items:
key = from_str(item['key'], spec[k]['key']) key = from_str(item["key"], spec[k]["key"])
val = from_str(item['value'], spec[k]['value']) val = from_str(item["value"], spec[k]["value"])
if k not in od: if k not in od:
od[k] = {} od[k] = {}
od[k][key] = val od[k][key] = val
@ -843,7 +990,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
dic = xmltodict.parse(xml) dic = xmltodict.parse(xml)
output_spec = service_spec.output_spec(operation) output_spec = service_spec.output_spec(operation)
try: try:
for k in (result_node or (operation + 'Response', operation + 'Result')): for k in result_node or (operation + "Response", operation + "Result"):
dic = dic[k] dic = dic[k]
except KeyError: except KeyError:
return None return None

View File

@ -1,14 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import MotoAPIResponse from .responses import MotoAPIResponse
url_bases = [ url_bases = ["https?://motoapi.amazonaws.com"]
"https?://motoapi.amazonaws.com"
]
response_instance = MotoAPIResponse() response_instance = MotoAPIResponse()
url_paths = { url_paths = {
'{0}/moto-api/$': response_instance.dashboard, "{0}/moto-api/$": response_instance.dashboard,
'{0}/moto-api/data.json': response_instance.model_data, "{0}/moto-api/data.json": response_instance.model_data,
'{0}/moto-api/reset': response_instance.reset_response, "{0}/moto-api/reset": response_instance.reset_response,
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
} }

View File

@ -8,6 +8,7 @@ import random
import re import re
import six import six
import string import string
from botocore.exceptions import ClientError
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
@ -15,9 +16,9 @@ REQUEST_ID_LONG = string.digits + string.ascii_uppercase
def camelcase_to_underscores(argument): def camelcase_to_underscores(argument):
''' Converts a camelcase param like theNewAttribute to the equivalent """ Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute''' python underscore variable like the_new_attribute"""
result = '' result = ""
prev_char_title = True prev_char_title = True
if not argument: if not argument:
return argument return argument
@ -41,18 +42,18 @@ def camelcase_to_underscores(argument):
def underscores_to_camelcase(argument): def underscores_to_camelcase(argument):
''' Converts a camelcase param like the_new_attribute to the equivalent """ Converts a camelcase param like the_new_attribute to the equivalent
camelcase version like theNewAttribute. Note that the first letter is camelcase version like theNewAttribute. Note that the first letter is
NOT capitalized by this function ''' NOT capitalized by this function """
result = '' result = ""
previous_was_underscore = False previous_was_underscore = False
for char in argument: for char in argument:
if char != '_': if char != "_":
if previous_was_underscore: if previous_was_underscore:
result += char.upper() result += char.upper()
else: else:
result += char result += char
previous_was_underscore = char == '_' previous_was_underscore = char == "_"
return result return result
@ -69,12 +70,18 @@ def method_names_from_class(clazz):
def get_random_hex(length=8): def get_random_hex(length=8):
chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f'] chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"]
return ''.join(six.text_type(random.choice(chars)) for x in range(length)) return "".join(six.text_type(random.choice(chars)) for x in range(length))
def get_random_message_id(): def get_random_message_id():
return '{0}-{1}-{2}-{3}-{4}'.format(get_random_hex(8), get_random_hex(4), get_random_hex(4), get_random_hex(4), get_random_hex(12)) return "{0}-{1}-{2}-{3}-{4}".format(
get_random_hex(8),
get_random_hex(4),
get_random_hex(4),
get_random_hex(4),
get_random_hex(12),
)
def convert_regex_to_flask_path(url_path): def convert_regex_to_flask_path(url_path):
@ -97,7 +104,6 @@ def convert_regex_to_flask_path(url_path):
class convert_httpretty_response(object): class convert_httpretty_response(object):
def __init__(self, callback): def __init__(self, callback):
self.callback = callback self.callback = callback
@ -114,13 +120,12 @@ class convert_httpretty_response(object):
def __call__(self, request, url, headers, **kwargs): def __call__(self, request, url, headers, **kwargs):
result = self.callback(request, url, headers) result = self.callback(request, url, headers)
status, headers, response = result status, headers, response = result
if 'server' not in headers: if "server" not in headers:
headers["server"] = "amazon.com" headers["server"] = "amazon.com"
return status, headers, response return status, headers, response
class convert_flask_to_httpretty_response(object): class convert_flask_to_httpretty_response(object):
def __init__(self, callback): def __init__(self, callback):
self.callback = callback self.callback = callback
@ -137,7 +142,10 @@ class convert_flask_to_httpretty_response(object):
def __call__(self, args=None, **kwargs): def __call__(self, args=None, **kwargs):
from flask import request, Response from flask import request, Response
result = self.callback(request, request.url, {}) try:
result = self.callback(request, request.url, {})
except ClientError as exc:
result = 400, {}, exc.response["Error"]["Message"]
# result is a status, headers, response tuple # result is a status, headers, response tuple
if len(result) == 3: if len(result) == 3:
status, headers, content = result status, headers, content = result
@ -145,13 +153,12 @@ class convert_flask_to_httpretty_response(object):
status, headers, content = 200, {}, result status, headers, content = 200, {}, result
response = Response(response=content, status=status, headers=headers) response = Response(response=content, status=status, headers=headers)
if request.method == "HEAD" and 'content-length' in headers: if request.method == "HEAD" and "content-length" in headers:
response.headers['Content-Length'] = headers['content-length'] response.headers["Content-Length"] = headers["content-length"]
return response return response
class convert_flask_to_responses_response(object): class convert_flask_to_responses_response(object):
def __init__(self, callback): def __init__(self, callback):
self.callback = callback self.callback = callback
@ -176,14 +183,14 @@ class convert_flask_to_responses_response(object):
def iso_8601_datetime_with_milliseconds(datetime): def iso_8601_datetime_with_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + 'Z' return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def iso_8601_datetime_without_milliseconds(datetime): def iso_8601_datetime_without_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z' return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
RFC1123 = '%a, %d %b %Y %H:%M:%S GMT' RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(datetime): def rfc_1123_datetime(datetime):
@ -212,16 +219,16 @@ def gen_amz_crc32(response, headerdict=None):
crc = str(binascii.crc32(response)) crc = str(binascii.crc32(response))
if headerdict is not None and isinstance(headerdict, dict): if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amz-crc32': crc}) headerdict.update({"x-amz-crc32": crc})
return crc return crc
def gen_amzn_requestid_long(headerdict=None): def gen_amzn_requestid_long(headerdict=None):
req_id = ''.join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)]) req_id = "".join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
if headerdict is not None and isinstance(headerdict, dict): if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amzn-requestid': req_id}) headerdict.update({"x-amzn-requestid": req_id})
return req_id return req_id
@ -239,13 +246,13 @@ def amz_crc32(f):
else: else:
if len(response) == 2: if len(response) == 2:
body, new_headers = response body, new_headers = response
status = new_headers.get('status', 200) status = new_headers.get("status", 200)
else: else:
status, new_headers, body = response status, new_headers, body = response
headers.update(new_headers) headers.update(new_headers)
# Cast status to string # Cast status to string
if "status" in headers: if "status" in headers:
headers['status'] = str(headers['status']) headers["status"] = str(headers["status"])
try: try:
# Doesnt work on python2 for some odd unicode strings # Doesnt work on python2 for some odd unicode strings
@ -271,7 +278,7 @@ def amzn_request_id(f):
else: else:
if len(response) == 2: if len(response) == 2:
body, new_headers = response body, new_headers = response
status = new_headers.get('status', 200) status = new_headers.get("status", 200)
else: else:
status, new_headers, body = response status, new_headers, body = response
headers.update(new_headers) headers.update(new_headers)
@ -280,7 +287,7 @@ def amzn_request_id(f):
# Update request ID in XML # Update request ID in XML
try: try:
body = re.sub(r'(?<=<RequestId>).*(?=<\/RequestId>)', request_id, body) body = re.sub(r"(?<=<RequestId>).*(?=<\/RequestId>)", request_id, body)
except Exception: # Will just ignore if it cant work on bytes (which are str's on python2) except Exception: # Will just ignore if it cant work on bytes (which are str's on python2)
pass pass
@ -293,7 +300,31 @@ def path_url(url):
parsed_url = urlparse(url) parsed_url = urlparse(url)
path = parsed_url.path path = parsed_url.path
if not path: if not path:
path = '/' path = "/"
if parsed_url.query: if parsed_url.query:
path = path + '?' + parsed_url.query path = path + "?" + parsed_url.query
return path return path
def py2_strip_unicode_keys(blob):
"""For Python 2 Only -- this will convert unicode keys in nested Dicts, Lists, and Sets to standard strings."""
if type(blob) == unicode: # noqa
return str(blob)
elif type(blob) == dict:
for key in list(blob.keys()):
value = blob.pop(key)
blob[str(key)] = py2_strip_unicode_keys(value)
elif type(blob) == list:
for i in range(0, len(blob)):
blob[i] = py2_strip_unicode_keys(blob[i])
elif type(blob) == set:
new_set = set()
for value in blob:
new_set.add(py2_strip_unicode_keys(value))
blob = new_set
return blob

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import datapipeline_backends from .models import datapipeline_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
datapipeline_backend = datapipeline_backends['us-east-1'] datapipeline_backend = datapipeline_backends["us-east-1"]
mock_datapipeline = base_decorator(datapipeline_backends) mock_datapipeline = base_decorator(datapipeline_backends)
mock_datapipeline_deprecated = deprecated_base_decorator(datapipeline_backends) mock_datapipeline_deprecated = deprecated_base_decorator(datapipeline_backends)

View File

@ -1,92 +1,73 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime import datetime
import boto.datapipeline from boto3 import Session
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys
class PipelineObject(BaseModel): class PipelineObject(BaseModel):
def __init__(self, object_id, name, fields): def __init__(self, object_id, name, fields):
self.object_id = object_id self.object_id = object_id
self.name = name self.name = name
self.fields = fields self.fields = fields
def to_json(self): def to_json(self):
return { return {"fields": self.fields, "id": self.object_id, "name": self.name}
"fields": self.fields,
"id": self.object_id,
"name": self.name,
}
class Pipeline(BaseModel): class Pipeline(BaseModel):
def __init__(self, name, unique_id, **kwargs): def __init__(self, name, unique_id, **kwargs):
self.name = name self.name = name
self.unique_id = unique_id self.unique_id = unique_id
self.description = kwargs.get('description', '') self.description = kwargs.get("description", "")
self.pipeline_id = get_random_pipeline_id() self.pipeline_id = get_random_pipeline_id()
self.creation_time = datetime.datetime.utcnow() self.creation_time = datetime.datetime.utcnow()
self.objects = [] self.objects = []
self.status = "PENDING" self.status = "PENDING"
self.tags = kwargs.get('tags', []) self.tags = kwargs.get("tags", [])
@property @property
def physical_resource_id(self): def physical_resource_id(self):
return self.pipeline_id return self.pipeline_id
def to_meta_json(self): def to_meta_json(self):
return { return {"id": self.pipeline_id, "name": self.name}
"id": self.pipeline_id,
"name": self.name,
}
def to_json(self): def to_json(self):
return { return {
"description": self.description, "description": self.description,
"fields": [{ "fields": [
"key": "@pipelineState", {"key": "@pipelineState", "stringValue": self.status},
"stringValue": self.status, {"key": "description", "stringValue": self.description},
}, { {"key": "name", "stringValue": self.name},
"key": "description", {
"stringValue": self.description "key": "@creationTime",
}, { "stringValue": datetime.datetime.strftime(
"key": "name", self.creation_time, "%Y-%m-%dT%H-%M-%S"
"stringValue": self.name ),
}, { },
"key": "@creationTime", {"key": "@id", "stringValue": self.pipeline_id},
"stringValue": datetime.datetime.strftime(self.creation_time, '%Y-%m-%dT%H-%M-%S'), {"key": "@sphere", "stringValue": "PIPELINE"},
}, { {"key": "@version", "stringValue": "1"},
"key": "@id", {"key": "@userId", "stringValue": "924374875933"},
"stringValue": self.pipeline_id, {"key": "@accountId", "stringValue": "924374875933"},
}, { {"key": "uniqueId", "stringValue": self.unique_id},
"key": "@sphere", ],
"stringValue": "PIPELINE"
}, {
"key": "@version",
"stringValue": "1"
}, {
"key": "@userId",
"stringValue": "924374875933"
}, {
"key": "@accountId",
"stringValue": "924374875933"
}, {
"key": "uniqueId",
"stringValue": self.unique_id
}],
"name": self.name, "name": self.name,
"pipelineId": self.pipeline_id, "pipelineId": self.pipeline_id,
"tags": self.tags "tags": self.tags,
} }
def set_pipeline_objects(self, pipeline_objects): def set_pipeline_objects(self, pipeline_objects):
self.objects = [ self.objects = [
PipelineObject(pipeline_object['id'], pipeline_object[ PipelineObject(
'name'], pipeline_object['fields']) pipeline_object["id"],
pipeline_object["name"],
pipeline_object["fields"],
)
for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects) for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects)
] ]
@ -94,15 +75,19 @@ class Pipeline(BaseModel):
self.status = "SCHEDULED" self.status = "SCHEDULED"
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
datapipeline_backend = datapipeline_backends[region_name] datapipeline_backend = datapipeline_backends[region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
cloudformation_unique_id = "cf-" + properties["Name"] cloudformation_unique_id = "cf-" + properties["Name"]
pipeline = datapipeline_backend.create_pipeline( pipeline = datapipeline_backend.create_pipeline(
properties["Name"], cloudformation_unique_id) properties["Name"], cloudformation_unique_id
)
datapipeline_backend.put_pipeline_definition( datapipeline_backend.put_pipeline_definition(
pipeline.pipeline_id, properties["PipelineObjects"]) pipeline.pipeline_id, properties["PipelineObjects"]
)
if properties["Activate"]: if properties["Activate"]:
pipeline.activate() pipeline.activate()
@ -110,7 +95,6 @@ class Pipeline(BaseModel):
class DataPipelineBackend(BaseBackend): class DataPipelineBackend(BaseBackend):
def __init__(self): def __init__(self):
self.pipelines = OrderedDict() self.pipelines = OrderedDict()
@ -123,8 +107,11 @@ class DataPipelineBackend(BaseBackend):
return self.pipelines.values() return self.pipelines.values()
def describe_pipelines(self, pipeline_ids): def describe_pipelines(self, pipeline_ids):
pipelines = [pipeline for pipeline in self.pipelines.values( pipelines = [
) if pipeline.pipeline_id in pipeline_ids] pipeline
for pipeline in self.pipelines.values()
if pipeline.pipeline_id in pipeline_ids
]
return pipelines return pipelines
def get_pipeline(self, pipeline_id): def get_pipeline(self, pipeline_id):
@ -144,7 +131,8 @@ class DataPipelineBackend(BaseBackend):
def describe_objects(self, object_ids, pipeline_id): def describe_objects(self, object_ids, pipeline_id):
pipeline = self.get_pipeline(pipeline_id) pipeline = self.get_pipeline(pipeline_id)
pipeline_objects = [ pipeline_objects = [
pipeline_object for pipeline_object in pipeline.objects pipeline_object
for pipeline_object in pipeline.objects
if pipeline_object.object_id in object_ids if pipeline_object.object_id in object_ids
] ]
return pipeline_objects return pipeline_objects
@ -155,5 +143,11 @@ class DataPipelineBackend(BaseBackend):
datapipeline_backends = {} datapipeline_backends = {}
for region in boto.datapipeline.regions(): for region in Session().get_available_regions("datapipeline"):
datapipeline_backends[region.name] = DataPipelineBackend() datapipeline_backends[region] = DataPipelineBackend()
for region in Session().get_available_regions(
"datapipeline", partition_name="aws-us-gov"
):
datapipeline_backends[region] = DataPipelineBackend()
for region in Session().get_available_regions("datapipeline", partition_name="aws-cn"):
datapipeline_backends[region] = DataPipelineBackend(region)

View File

@ -7,7 +7,6 @@ from .models import datapipeline_backends
class DataPipelineResponse(BaseResponse): class DataPipelineResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
# TODO this should really be moved to core/responses.py # TODO this should really be moved to core/responses.py
@ -21,47 +20,47 @@ class DataPipelineResponse(BaseResponse):
return datapipeline_backends[self.region] return datapipeline_backends[self.region]
def create_pipeline(self): def create_pipeline(self):
name = self.parameters.get('name') name = self.parameters.get("name")
unique_id = self.parameters.get('uniqueId') unique_id = self.parameters.get("uniqueId")
description = self.parameters.get('description', '') description = self.parameters.get("description", "")
tags = self.parameters.get('tags', []) tags = self.parameters.get("tags", [])
pipeline = self.datapipeline_backend.create_pipeline(name, unique_id, description=description, tags=tags) pipeline = self.datapipeline_backend.create_pipeline(
return json.dumps({ name, unique_id, description=description, tags=tags
"pipelineId": pipeline.pipeline_id, )
}) return json.dumps({"pipelineId": pipeline.pipeline_id})
def list_pipelines(self): def list_pipelines(self):
pipelines = list(self.datapipeline_backend.list_pipelines()) pipelines = list(self.datapipeline_backend.list_pipelines())
pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines] pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines]
max_pipelines = 50 max_pipelines = 50
marker = self.parameters.get('marker') marker = self.parameters.get("marker")
if marker: if marker:
start = pipeline_ids.index(marker) + 1 start = pipeline_ids.index(marker) + 1
else: else:
start = 0 start = 0
pipelines_resp = pipelines[start:start + max_pipelines] pipelines_resp = pipelines[start : start + max_pipelines]
has_more_results = False has_more_results = False
marker = None marker = None
if start + max_pipelines < len(pipeline_ids) - 1: if start + max_pipelines < len(pipeline_ids) - 1:
has_more_results = True has_more_results = True
marker = pipelines_resp[-1].pipeline_id marker = pipelines_resp[-1].pipeline_id
return json.dumps({ return json.dumps(
"hasMoreResults": has_more_results, {
"marker": marker, "hasMoreResults": has_more_results,
"pipelineIdList": [ "marker": marker,
pipeline.to_meta_json() for pipeline in pipelines_resp "pipelineIdList": [
] pipeline.to_meta_json() for pipeline in pipelines_resp
}) ],
}
)
def describe_pipelines(self): def describe_pipelines(self):
pipeline_ids = self.parameters["pipelineIds"] pipeline_ids = self.parameters["pipelineIds"]
pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids) pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids)
return json.dumps({ return json.dumps(
"pipelineDescriptionList": [ {"pipelineDescriptionList": [pipeline.to_json() for pipeline in pipelines]}
pipeline.to_json() for pipeline in pipelines )
]
})
def delete_pipeline(self): def delete_pipeline(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
@ -72,31 +71,38 @@ class DataPipelineResponse(BaseResponse):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
pipeline_objects = self.parameters["pipelineObjects"] pipeline_objects = self.parameters["pipelineObjects"]
self.datapipeline_backend.put_pipeline_definition( self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects)
pipeline_id, pipeline_objects)
return json.dumps({"errored": False}) return json.dumps({"errored": False})
def get_pipeline_definition(self): def get_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
pipeline_definition = self.datapipeline_backend.get_pipeline_definition( pipeline_definition = self.datapipeline_backend.get_pipeline_definition(
pipeline_id) pipeline_id
return json.dumps({ )
"pipelineObjects": [pipeline_object.to_json() for pipeline_object in pipeline_definition] return json.dumps(
}) {
"pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_definition
]
}
)
def describe_objects(self): def describe_objects(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
object_ids = self.parameters["objectIds"] object_ids = self.parameters["objectIds"]
pipeline_objects = self.datapipeline_backend.describe_objects( pipeline_objects = self.datapipeline_backend.describe_objects(
object_ids, pipeline_id) object_ids, pipeline_id
return json.dumps({ )
"hasMoreResults": False, return json.dumps(
"marker": None, {
"pipelineObjects": [ "hasMoreResults": False,
pipeline_object.to_json() for pipeline_object in pipeline_objects "marker": None,
] "pipelineObjects": [
}) pipeline_object.to_json() for pipeline_object in pipeline_objects
],
}
)
def activate_pipeline(self): def activate_pipeline(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import DataPipelineResponse from .responses import DataPipelineResponse
url_bases = [ url_bases = ["https?://datapipeline.(.+).amazonaws.com"]
"https?://datapipeline.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": DataPipelineResponse.dispatch}
'{0}/$': DataPipelineResponse.dispatch,
}

View File

@ -1,5 +1,5 @@
import collections
import six import six
from moto.compat import collections_abc
from moto.core.utils import get_random_hex from moto.core.utils import get_random_hex
@ -8,13 +8,15 @@ def get_random_pipeline_id():
def remove_capitalization_of_dict_keys(obj): def remove_capitalization_of_dict_keys(obj):
if isinstance(obj, collections.Mapping): if isinstance(obj, collections_abc.Mapping):
result = obj.__class__() result = obj.__class__()
for key, value in obj.items(): for key, value in obj.items():
normalized_key = key[:1].lower() + key[1:] normalized_key = key[:1].lower() + key[1:]
result[normalized_key] = remove_capitalization_of_dict_keys(value) result[normalized_key] = remove_capitalization_of_dict_keys(value)
return result return result
elif isinstance(obj, collections.Iterable) and not isinstance(obj, six.string_types): elif isinstance(obj, collections_abc.Iterable) and not isinstance(
obj, six.string_types
):
result = obj.__class__() result = obj.__class__()
for item in obj: for item in obj:
result += (remove_capitalization_of_dict_keys(item),) result += (remove_capitalization_of_dict_keys(item),)

View File

@ -0,0 +1,8 @@
from __future__ import unicode_literals
from ..core.models import base_decorator, deprecated_base_decorator
from .models import datasync_backends
datasync_backend = datasync_backends["us-east-1"]
mock_datasync = base_decorator(datasync_backends)
mock_datasync_deprecated = deprecated_base_decorator(datasync_backends)

View File

@ -0,0 +1,15 @@
from __future__ import unicode_literals
from moto.core.exceptions import JsonRESTError
class DataSyncClientError(JsonRESTError):
code = 400
class InvalidRequestException(DataSyncClientError):
def __init__(self, msg=None):
self.code = 400
super(InvalidRequestException, self).__init__(
"InvalidRequestException", msg or "The request is not valid."
)

235
moto/datasync/models.py Normal file
View File

@ -0,0 +1,235 @@
from boto3 import Session
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from .exceptions import InvalidRequestException
class Location(BaseModel):
def __init__(
self, location_uri, region_name=None, typ=None, metadata=None, arn_counter=0
):
self.uri = location_uri
self.region_name = region_name
self.metadata = metadata
self.typ = typ
# Generate ARN
self.arn = "arn:aws:datasync:{0}:111222333444:location/loc-{1}".format(
region_name, str(arn_counter).zfill(17)
)
class Task(BaseModel):
def __init__(
self,
source_location_arn,
destination_location_arn,
name,
region_name,
arn_counter=0,
metadata=None,
):
self.source_location_arn = source_location_arn
self.destination_location_arn = destination_location_arn
self.name = name
self.metadata = metadata
# For simplicity Tasks are either available or running
self.status = "AVAILABLE"
self.current_task_execution_arn = None
# Generate ARN
self.arn = "arn:aws:datasync:{0}:111222333444:task/task-{1}".format(
region_name, str(arn_counter).zfill(17)
)
class TaskExecution(BaseModel):
# For simplicity, task_execution can never fail
# Some documentation refers to this list:
# 'Status': 'QUEUED'|'LAUNCHING'|'PREPARING'|'TRANSFERRING'|'VERIFYING'|'SUCCESS'|'ERROR'
# Others refers to this list:
# INITIALIZING | PREPARING | TRANSFERRING | VERIFYING | SUCCESS/FAILURE
# Checking with AWS Support...
TASK_EXECUTION_INTERMEDIATE_STATES = (
"INITIALIZING",
# 'QUEUED', 'LAUNCHING',
"PREPARING",
"TRANSFERRING",
"VERIFYING",
)
TASK_EXECUTION_FAILURE_STATES = ("ERROR",)
TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
# Also COMPLETED state?
def __init__(self, task_arn, arn_counter=0):
self.task_arn = task_arn
self.arn = "{0}/execution/exec-{1}".format(task_arn, str(arn_counter).zfill(17))
self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[0]
# Simulate a task execution
def iterate_status(self):
if self.status in self.TASK_EXECUTION_FAILURE_STATES:
return
if self.status in self.TASK_EXECUTION_SUCCESS_STATES:
return
if self.status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
for i, status in enumerate(self.TASK_EXECUTION_INTERMEDIATE_STATES):
if status == self.status:
if i < len(self.TASK_EXECUTION_INTERMEDIATE_STATES) - 1:
self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[i + 1]
else:
self.status = self.TASK_EXECUTION_SUCCESS_STATES[0]
return
raise Exception(
"TaskExecution.iterate_status: Unknown status={0}".format(self.status)
)
def cancel(self):
if self.status not in self.TASK_EXECUTION_INTERMEDIATE_STATES:
raise InvalidRequestException(
"Sync task cannot be cancelled in its current status: {0}".format(
self.status
)
)
self.status = "ERROR"
class DataSyncBackend(BaseBackend):
def __init__(self, region_name):
self.region_name = region_name
# Always increase when new things are created
# This ensures uniqueness
self.arn_counter = 0
self.locations = OrderedDict()
self.tasks = OrderedDict()
self.task_executions = OrderedDict()
def reset(self):
region_name = self.region_name
self._reset_model_refs()
self.__dict__ = {}
self.__init__(region_name)
def create_location(self, location_uri, typ=None, metadata=None):
"""
# AWS DataSync allows for duplicate LocationUris
for arn, location in self.locations.items():
if location.uri == location_uri:
raise Exception('Location already exists')
"""
if not typ:
raise Exception("Location type must be specified")
self.arn_counter = self.arn_counter + 1
location = Location(
location_uri,
region_name=self.region_name,
arn_counter=self.arn_counter,
metadata=metadata,
typ=typ,
)
self.locations[location.arn] = location
return location.arn
def _get_location(self, location_arn, typ):
if location_arn not in self.locations:
raise InvalidRequestException(
"Location {0} is not found.".format(location_arn)
)
location = self.locations[location_arn]
if location.typ != typ:
raise InvalidRequestException(
"Invalid Location type: {0}".format(location.typ)
)
return location
def delete_location(self, location_arn):
if location_arn in self.locations:
del self.locations[location_arn]
else:
raise InvalidRequestException
def create_task(
self, source_location_arn, destination_location_arn, name, metadata=None
):
if source_location_arn not in self.locations:
raise InvalidRequestException(
"Location {0} not found.".format(source_location_arn)
)
if destination_location_arn not in self.locations:
raise InvalidRequestException(
"Location {0} not found.".format(destination_location_arn)
)
self.arn_counter = self.arn_counter + 1
task = Task(
source_location_arn,
destination_location_arn,
name,
region_name=self.region_name,
arn_counter=self.arn_counter,
metadata=metadata,
)
self.tasks[task.arn] = task
return task.arn
def _get_task(self, task_arn):
if task_arn in self.tasks:
return self.tasks[task_arn]
else:
raise InvalidRequestException
def update_task(self, task_arn, name, metadata):
if task_arn in self.tasks:
task = self.tasks[task_arn]
task.name = name
task.metadata = metadata
else:
raise InvalidRequestException(
"Sync task {0} is not found.".format(task_arn)
)
def delete_task(self, task_arn):
if task_arn in self.tasks:
del self.tasks[task_arn]
else:
raise InvalidRequestException
def start_task_execution(self, task_arn):
self.arn_counter = self.arn_counter + 1
if task_arn in self.tasks:
task = self.tasks[task_arn]
if task.status == "AVAILABLE":
task_execution = TaskExecution(task_arn, arn_counter=self.arn_counter)
self.task_executions[task_execution.arn] = task_execution
self.tasks[task_arn].current_task_execution_arn = task_execution.arn
self.tasks[task_arn].status = "RUNNING"
return task_execution.arn
raise InvalidRequestException("Invalid request.")
def _get_task_execution(self, task_execution_arn):
if task_execution_arn in self.task_executions:
return self.task_executions[task_execution_arn]
else:
raise InvalidRequestException
def cancel_task_execution(self, task_execution_arn):
if task_execution_arn in self.task_executions:
task_execution = self.task_executions[task_execution_arn]
task_execution.cancel()
task_arn = task_execution.task_arn
self.tasks[task_arn].current_task_execution_arn = None
self.tasks[task_arn].status = "AVAILABLE"
return
raise InvalidRequestException(
"Sync task {0} is not found.".format(task_execution_arn)
)
datasync_backends = {}
for region in Session().get_available_regions("datasync"):
datasync_backends[region] = DataSyncBackend(region)
for region in Session().get_available_regions("datasync", partition_name="aws-us-gov"):
datasync_backends[region] = DataSyncBackend(region)
for region in Session().get_available_regions("datasync", partition_name="aws-cn"):
datasync_backends[region] = DataSyncBackend(region)

162
moto/datasync/responses.py Normal file
View File

@ -0,0 +1,162 @@
import json
from moto.core.responses import BaseResponse
from .models import datasync_backends
class DataSyncResponse(BaseResponse):
@property
def datasync_backend(self):
return datasync_backends[self.region]
def list_locations(self):
locations = list()
for arn, location in self.datasync_backend.locations.items():
locations.append({"LocationArn": location.arn, "LocationUri": location.uri})
return json.dumps({"Locations": locations})
def _get_location(self, location_arn, typ):
return self.datasync_backend._get_location(location_arn, typ)
def create_location_s3(self):
# s3://bucket_name/folder/
s3_bucket_arn = self._get_param("S3BucketArn")
subdirectory = self._get_param("Subdirectory")
metadata = {"S3Config": self._get_param("S3Config")}
location_uri_elts = ["s3:/", s3_bucket_arn.split(":")[-1]]
if subdirectory:
location_uri_elts.append(subdirectory)
location_uri = "/".join(location_uri_elts)
arn = self.datasync_backend.create_location(
location_uri, metadata=metadata, typ="S3"
)
return json.dumps({"LocationArn": arn})
def describe_location_s3(self):
location_arn = self._get_param("LocationArn")
location = self._get_location(location_arn, typ="S3")
return json.dumps(
{
"LocationArn": location.arn,
"LocationUri": location.uri,
"S3Config": location.metadata["S3Config"],
}
)
def create_location_smb(self):
# smb://smb.share.fqdn/AWS_Test/
subdirectory = self._get_param("Subdirectory")
server_hostname = self._get_param("ServerHostname")
metadata = {
"AgentArns": self._get_param("AgentArns"),
"User": self._get_param("User"),
"Domain": self._get_param("Domain"),
"MountOptions": self._get_param("MountOptions"),
}
location_uri = "/".join(["smb:/", server_hostname, subdirectory])
arn = self.datasync_backend.create_location(
location_uri, metadata=metadata, typ="SMB"
)
return json.dumps({"LocationArn": arn})
def describe_location_smb(self):
location_arn = self._get_param("LocationArn")
location = self._get_location(location_arn, typ="SMB")
return json.dumps(
{
"LocationArn": location.arn,
"LocationUri": location.uri,
"AgentArns": location.metadata["AgentArns"],
"User": location.metadata["User"],
"Domain": location.metadata["Domain"],
"MountOptions": location.metadata["MountOptions"],
}
)
def delete_location(self):
location_arn = self._get_param("LocationArn")
self.datasync_backend.delete_location(location_arn)
return json.dumps({})
def create_task(self):
destination_location_arn = self._get_param("DestinationLocationArn")
source_location_arn = self._get_param("SourceLocationArn")
name = self._get_param("Name")
metadata = {
"CloudWatchLogGroupArn": self._get_param("CloudWatchLogGroupArn"),
"Options": self._get_param("Options"),
"Excludes": self._get_param("Excludes"),
"Tags": self._get_param("Tags"),
}
arn = self.datasync_backend.create_task(
source_location_arn, destination_location_arn, name, metadata=metadata
)
return json.dumps({"TaskArn": arn})
def update_task(self):
task_arn = self._get_param("TaskArn")
self.datasync_backend.update_task(
task_arn,
name=self._get_param("Name"),
metadata={
"CloudWatchLogGroupArn": self._get_param("CloudWatchLogGroupArn"),
"Options": self._get_param("Options"),
"Excludes": self._get_param("Excludes"),
"Tags": self._get_param("Tags"),
},
)
return json.dumps({})
def list_tasks(self):
tasks = list()
for arn, task in self.datasync_backend.tasks.items():
tasks.append(
{"Name": task.name, "Status": task.status, "TaskArn": task.arn}
)
return json.dumps({"Tasks": tasks})
def delete_task(self):
task_arn = self._get_param("TaskArn")
self.datasync_backend.delete_task(task_arn)
return json.dumps({})
def describe_task(self):
task_arn = self._get_param("TaskArn")
task = self.datasync_backend._get_task(task_arn)
return json.dumps(
{
"TaskArn": task.arn,
"Status": task.status,
"Name": task.name,
"CurrentTaskExecutionArn": task.current_task_execution_arn,
"SourceLocationArn": task.source_location_arn,
"DestinationLocationArn": task.destination_location_arn,
"CloudWatchLogGroupArn": task.metadata["CloudWatchLogGroupArn"],
"Options": task.metadata["Options"],
"Excludes": task.metadata["Excludes"],
}
)
def start_task_execution(self):
task_arn = self._get_param("TaskArn")
arn = self.datasync_backend.start_task_execution(task_arn)
return json.dumps({"TaskExecutionArn": arn})
def cancel_task_execution(self):
task_execution_arn = self._get_param("TaskExecutionArn")
self.datasync_backend.cancel_task_execution(task_execution_arn)
return json.dumps({})
def describe_task_execution(self):
task_execution_arn = self._get_param("TaskExecutionArn")
task_execution = self.datasync_backend._get_task_execution(task_execution_arn)
result = json.dumps(
{"TaskExecutionArn": task_execution.arn, "Status": task_execution.status}
)
if task_execution.status == "SUCCESS":
self.datasync_backend.tasks[task_execution.task_arn].status = "AVAILABLE"
# Simulate task being executed
task_execution.iterate_status()
return result

7
moto/datasync/urls.py Normal file
View File

@ -0,0 +1,7 @@
from __future__ import unicode_literals
from .responses import DataSyncResponse
url_bases = ["https?://(.*?)(datasync)(.*?).amazonaws.com"]
url_paths = {"{0}/$": DataSyncResponse.dispatch}

Some files were not shown because too many files have changed in this diff Show More