Merge remote-tracking branch 'upstream/master'

mickeypash 2020-04-02 12:35:47 +01:00
commit d45e2d2e53
576 changed files with 84112 additions and 45577 deletions

.gitignore

@@ -15,7 +15,10 @@ python_env
.ropeproject/
.pytest_cache/
venv/
env/
.python-version
.vscode/
tests/file.tmp
.eggs/
.mypy_cache/
*.tmp

.travis.yml

@@ -1,12 +1,12 @@
dist: xenial
dist: bionic
language: python
sudo: false
services:
- docker
python:
- 2.7
- 3.6
- 3.7
- 3.8
env:
- TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true
@@ -17,19 +17,29 @@ install:
python setup.py sdist
if [ "$TEST_SERVER_MODE" = "true" ]; then
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh &
if [ "$TRAVIS_PYTHON_VERSION" = "3.8" ]; then
# Python 3.8 does not provide Stretch images yet [1]
# [1] https://github.com/docker-library/python/issues/428
PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-buster
else
PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-stretch
fi
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh &
fi
travis_retry pip install -r requirements-dev.txt
travis_retry pip install boto==2.45.0
travis_retry pip install boto3
travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt
travis_retry pip install coverage==4.5.4
if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py
fi
before_script:
- if [[ $TRAVIS_PYTHON_VERSION == "3.7" ]]; then make lint; fi
script:
- make test
- make test-only
after_success:
- coveralls
before_deploy:

AUTHORS.md

@@ -57,3 +57,4 @@ Moto is written by Steve Pulec with contributions from:
* [Bendeguz Acs](https://github.com/acsbendi)
* [Craig Anderson](https://github.com/craiga)
* [Robert Lewis](https://github.com/ralewis85)
* [Kyle Jones](https://github.com/Kerl1310)

CHANGELOG.md

@@ -1,6 +1,189 @@
Moto Changelog
===================
1.3.14
-----
General Changes:
* Support for Python 3.8
* Linting: Black is now enforced.
New Services:
* Athena
* Config
* DataSync
* Step Functions
New methods:
* Athena:
* create_work_group()
* list_work_groups()
* API Gateway:
* delete_stage()
* update_api_key()
* CloudWatch Logs
* list_tags_log_group()
* tag_log_group()
* untag_log_group()
* Config
* batch_get_resource_config()
* delete_aggregation_authorization()
* delete_configuration_aggregator()
* describe_aggregation_authorizations()
* describe_configuration_aggregators()
* get_resource_config_history()
* list_aggregate_discovered_resources() (For S3)
* list_discovered_resources() (For S3)
* put_aggregation_authorization()
* put_configuration_aggregator()
* Cognito
* assume_role_with_web_identity()
* describe_identity_pool()
* get_open_id_token()
* update_user_pool_domain()
* DataSync:
* cancel_task_execution()
* create_location()
* create_task()
* start_task_execution()
* EC2:
* create_launch_template()
* create_launch_template_version()
* describe_launch_template_versions()
* describe_launch_templates()
* Glue
* batch_get_partition()
* IAM
* create_open_id_connect_provider()
* create_virtual_mfa_device()
* delete_account_password_policy()
* delete_open_id_connect_provider()
* delete_policy()
* delete_virtual_mfa_device()
* get_account_password_policy()
* get_open_id_connect_provider()
* list_open_id_connect_providers()
* list_virtual_mfa_devices()
* update_account_password_policy()
* Lambda
* create_event_source_mapping()
* delete_event_source_mapping()
* get_event_source_mapping()
* list_event_source_mappings()
* update_configuration()
* update_event_source_mapping()
* update_function_code()
* KMS
* decrypt()
* encrypt()
* generate_data_key_without_plaintext()
* generate_random()
* re_encrypt()
* SES
* send_templated_email()
* SNS
* add_permission()
* list_tags_for_resource()
* remove_permission()
* tag_resource()
* untag_resource()
* SSM
* describe_parameters()
* get_parameter_history()
* Step Functions
* create_state_machine()
* delete_state_machine()
* describe_execution()
* describe_state_machine()
* describe_state_machine_for_execution()
* list_executions()
* list_state_machines()
* list_tags_for_resource()
* start_execution()
* stop_execution()
* SQS
* list_queue_tags()
* send_message_batch()
General updates:
* API Gateway:
* Now generates valid IDs
* API Keys, Usage Plans now support tags
* ACM:
* list_certificates() accepts the status parameter
* Batch:
* submit_job() can now be called with job name
* CloudWatch Events
* Multi-region support
* CloudWatch Logs
* get_log_events() now supports pagination
* Cognito:
* Now throws UsernameExistsException for known users
* DynamoDB
* update_item() now supports lists, the list_append-operator and removing nested items
* delete_item() now supports condition expressions
* get_item() now supports projection expression
* Enforces 400KB item size
* Validation on duplicate keys in batch_get_item()
* Validation on AttributeDefinitions on create_table()
* Validation on Query Key Expression
* Projection Expressions now support nested attributes
* EC2:
* Change DesiredCapacity behaviour for AutoScaling groups
* Extend list of supported EC2 ENI properties
* Create ASG from Instance now supported
* ASGs attached to a terminated instance now recreate the instance if required
* Unify OwnerIDs
* ECS
* Task definition revision deregistration: remaining revisions now remain unchanged
* Fix created_at/updated_at format for deployments
* Support multiple regions
* ELB
* Return correct response when describing target health of stopped instances
* Target groups no longer show terminated instances
* 'fixed-response' now a supported action-type
* Now supports redirect: authenticate-cognito
* Kinesis FireHose
* Now supports ExtendedS3DestinationConfiguration
* KMS
* Now supports tags
* Organizations
* create_organization() now creates Master account
* Redshift
* Fix timezone problems when creating a cluster
* Support for enhanced_vpc_routing-parameter
* Route53
* Implemented UPSERT for change_resource_records
* S3:
* Support partNumber for head_object
* Support for INTELLIGENT_TIERING, GLACIER and DEEP_ARCHIVE
* Fix KeyCount attribute
* list_objects now supports pagination (next_marker)
* Support tagging for versioned objects
* STS
* Implement validation on policy length
* Lambda
* Support EventSourceMappings for SQS, DynamoDB
* get_function(), delete_function() now both support ARNs as parameters
* IAM
* Roles now support tags
* Policy Validation: SID can be empty
* Validate roles have no attachments when deleting
* SecretsManager
* Now supports binary secrets
* IOT
* update_thing_shadow validation
* delete_thing now also removed principals
* SQS
* Tags supported for create_queue()
1.3.7
-----

CONFIG_README.md

@@ -0,0 +1,120 @@
# AWS Config Querying Support in Moto
An experimental feature for AWS Config has been developed to provide AWS Config capabilities in your unit tests.
This feature is experimental as there are many services that are not yet supported and will require the community to add them in
over time. This page details how the feature works and how you can use it.
## What is this and why would I use this?
AWS Config is an AWS service that describes your AWS resource types and can track their changes over time. At this time, moto does not
have support for handling the configuration history changes, but it does have a few methods mocked out that can be immensely useful
for unit testing.
If you are developing automation that needs to pull against AWS Config, then this will help you write tests that can simulate your
code in production.
## How does this work?
The AWS Config capabilities in moto work by examining the state of resources that are created within moto, and then returning that data
in the way that AWS Config would return it (sans history). This will work by querying all of the moto backends (regions) for a given
resource type.
However, this will only work on resource types that have this enabled.
### Current enabled resource types:
1. S3
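As a quick illustration, a test along the following lines exercises the S3 support end to end (the bucket name and region are illustrative; the exact response fields are best confirmed against the unit tests referenced in the Developer Guide below):
```python
import boto3
from moto import mock_config, mock_s3


@mock_s3
@mock_config
def test_list_discovered_buckets():
    # Create a resource in moto's virtual AWS account...
    boto3.client("s3", region_name="us-east-1").create_bucket(Bucket="somebucket")

    # ...then query it back through the mocked AWS Config API.
    config = boto3.client("config", region_name="us-east-1")
    result = config.list_discovered_resources(resourceType="AWS::S3::Bucket")
    assert [r["resourceName"] for r in result["resourceIdentifiers"]] == ["somebucket"]
```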
## Developer Guide
There are several pieces to this for adding new capabilities to moto:
1. Listing resources
1. Describing resources
For both, there are a number of prerequisites:
### Base Components
In the `moto/core/models.py` file is a class named `ConfigQueryModel`. This is a base class that keeps track of all the
resource type backends.
At a minimum, resource types that have this enabled will have:
1. A `config.py` file that will import the resource type backends (from the `__init__.py`)
1. In the resource's `config.py`, an implementation of the `ConfigQueryModel` class with logic unique to the resource type
1. An instantiation of the `ConfigQueryModel`
1. In the `moto/config/models.py` file, import the `ConfigQueryModel` instantiation, and update `RESOURCE_MAP` to have a mapping of the AWS Config resource type
to the instantiation on the previous step (just imported).
An example of the above is implemented for S3. You can see that by looking at:
1. `moto/s3/config.py`
1. `moto/config/models.py`
As well as the corresponding unit tests in:
1. `tests/s3/test_s3.py`
1. `tests/config/test_config.py`
Note for unit testing, you will want to add a test to ensure that you can query all the resources effectively. For testing this feature,
the unit tests for the `ConfigQueryModel` will not make use of `boto` to create resources, such as S3 buckets. You will need to use the
backend model methods to provision the resources. This is to make tests compatible with the moto server. You should absolutely make tests
in the resource type to test listing and object fetching.
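To make that wiring concrete before getting into the details, here is a minimal sketch for a hypothetical `myservice` resource type -- everything below other than `ConfigQueryModel` is illustrative, and the method signatures follow the base class in `moto/core/models.py`:
```python
# moto/myservice/config.py (hypothetical module, mirroring moto/s3/config.py)
from moto.core.models import ConfigQueryModel
from moto.myservice import myservice_backends  # hypothetical per-region backends


class MyServiceConfigQuery(ConfigQueryModel):
    def list_config_service_resources(self, resource_ids, resource_name, limit,
                                      next_token, backend_region=None,
                                      resource_region=None):
        raise NotImplementedError()  # fleshed out in the Listing section below

    def get_config_resource(self, resource_id, resource_name=None,
                            backend_region=None, resource_region=None):
        raise NotImplementedError()  # fleshed out in Describing Resources below


myservice_config_query = MyServiceConfigQuery(myservice_backends)

# In moto/config/models.py, RESOURCE_MAP would then gain an entry such as:
#   "AWS::MyService::Thing": myservice_config_query
```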
### Listing
S3 is currently the model implementation, but it is also odd in that S3 is a global resource type with regional resource residency.
But for most resource types the following is true:
1. There are regional backends with their own sets of data
1. Config aggregation can pull data from any backend region -- we assume that everything lives in the same account
Implementing the listing capability will be different for each resource type. At a minimum, you will need to return a `List` of `Dict`s
that look like this:
```python
[
    {
        'type': 'AWS::The AWS Config data type',
        'name': 'The name of the resource',
        'id': 'The ID of the resource',
        'region': 'The region of the resource -- if global, then you may want to have the calling logic '
                  'pass in the aggregator region for the resource region -- or just us-east-1 :P'
    },
    ...
]
```
It's recommended to read the comment for the `ConfigQueryModel`'s `list_config_service_resources` function in [base class here](moto/core/models.py).
The AWS Config code will see this and format it correctly for both aggregated and non-aggregated calls.
#### General implementation tips
The aggregation and non-aggregation querying can and should just use the same overall logic. The differences are:
1. Non-aggregated listing will specify the region-name of the resource backend `backend_region`
1. Aggregated listing will need to be able to list resource types across ALL backends and filter optionally by passing in `resource_region`.
An example of a working implementation of this is [S3](moto/s3/config.py).
Pagination should generally be able to pull resources from any region, so page tokens should be sharded by `region-item-name` -- this is not done for S3 because S3 has a globally unique name space.
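Putting the two cases together, a sketch of the listing hook for the hypothetical resource type above could look like this (pagination is elided, and the `things` store and Config type name are made up):
```python
class MyServiceConfigQuery(ConfigQueryModel):
    def list_config_service_resources(self, resource_ids, resource_name, limit,
                                      next_token, backend_region=None,
                                      resource_region=None):
        # Non-aggregated: only the named backend region. Aggregated: every
        # region, optionally filtered down by resource_region.
        regions = [backend_region] if backend_region else list(self.backends)
        identifiers = []
        for region in regions:
            if resource_region and region != resource_region:
                continue
            for thing in self.backends[region].things.values():  # hypothetical store
                if resource_ids and thing.id not in resource_ids:
                    continue
                if resource_name and thing.name != resource_name:
                    continue
                identifiers.append({
                    "type": "AWS::MyService::Thing",  # hypothetical Config type
                    "id": thing.id,
                    "name": thing.name,
                    "region": region,
                })
        return identifiers[:limit], None  # real code computes a next_token here
```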
### Describing Resources
Fetching a resource's configuration has some similarities to listing resources, but it requires more work to implement. Due to the
various ways that a resource can be configured, some work will need to be done to ensure that the Config dict returned is correct.
For most resource types the following is true:
1. There are regional backends with their own sets of data
1. Config aggregation can pull data from any backend region -- we assume that everything lives in the same account
The current implementation is for S3. S3 is very complex, and what Config returns for a bucket depends on how the bucket is configured.
When implementing resource config fetching, you will need to return at a minimum `None` if the resource is not found, or a `dict` that looks
like what AWS Config would return.
It's recommended to read the comment for the `ConfigQueryModel`'s `get_config_resource` function in [base class here](moto/core/models.py).
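Continuing the same hypothetical sketch, the describe hook returns `None` when the resource is missing, or a Config-shaped `dict` when it exists (the real field set is much larger -- see `moto/s3/config.py`):
```python
class MyServiceConfigQuery(ConfigQueryModel):
    def get_config_resource(self, resource_id, resource_name=None,
                            backend_region=None, resource_region=None):
        for region, backend in self.backends.items():
            thing = backend.things.get(resource_id)  # hypothetical store
            if thing is None:
                continue
            if resource_name is not None and thing.name != resource_name:
                return None
            return {
                "version": "1.3",
                "resourceType": "AWS::MyService::Thing",  # hypothetical
                "resourceId": thing.id,
                "resourceName": thing.name,
                "awsRegion": region,
                "configuration": {"name": thing.name},  # resource-specific shape
            }
        return None  # AWS Config reports the resource as not found
```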

File diff suppressed because it is too large

Makefile

@@ -14,12 +14,16 @@ init:
lint:
flake8 moto
black --check moto/ tests/
test: lint
test-only:
rm -f .coverage
rm -rf cover
@nosetests -sv --with-coverage --cover-html ./tests/ $(TEST_EXCLUDE)
test: lint test-only
test_server:
@TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/
@@ -27,7 +31,8 @@ aws_managed_policies:
scripts/update_managed_policies.py
upload_pypi_artifact:
python setup.py sdist bdist_wheel upload
python setup.py sdist bdist_wheel
twine upload dist/*
push_dockerhub_image:
docker build -t motoserver/moto .

README.md

@@ -7,9 +7,9 @@
[![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org)
![PyPI](https://img.shields.io/pypi/v/moto.svg)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
# In a nutshell
## In a nutshell
Moto is a library that allows your tests to easily mock out AWS Services.
@@ -78,6 +78,7 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
| Cognito Identity Provider | @mock_cognitoidp | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Config | @mock_config | basic endpoints done |
| Config | @mock_config | core endpoints done |
|-------------------------------------------------------------------------------------|
| Data Pipeline | @mock_datapipeline | basic endpoints done |
|-------------------------------------------------------------------------------------|
@@ -255,6 +256,140 @@ def test_my_model_save():
mock.stop()
```
## IAM-like Access Control
Moto also has the ability to authenticate and authorize actions, just like it's done by IAM in AWS. This functionality can be enabled by either setting the `INITIAL_NO_AUTH_ACTION_COUNT` environment variable or using the `set_initial_no_auth_action_count` decorator. Note that the current implementation is very basic, see [this file](https://github.com/spulec/moto/blob/master/moto/core/access_control.py) for more information.
### `INITIAL_NO_AUTH_ACTION_COUNT`
If this environment variable is set, moto will skip authentication for that many requests and only start authenticating requests afterwards. If it is not set, it defaults to infinity, so moto will never perform any authentication at all.
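For example, to let exactly four setup calls through before authentication kicks in, you could set the variable before moto is imported (the value 4 here is illustrative):
```python
import os

os.environ["INITIAL_NO_AUTH_ACTION_COUNT"] = "4"
```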
### `set_initial_no_auth_action_count`
This is a decorator that works similarly to the environment variable, but the settings are only valid in the function's scope. When the function returns, everything is restored.
```python
@set_initial_no_auth_action_count(4)
@mock_ec2
def test_describe_instances_allowed():
policy_document = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:Describe*",
"Resource": "*"
}
]
}
access_key = ...
# create access key for an IAM user/assumed role that has the policy above.
# this part should call __exactly__ 4 AWS actions, so that authentication and authorization starts exactly after this
client = boto3.client('ec2', region_name='us-east-1',
aws_access_key_id=access_key['AccessKeyId'],
aws_secret_access_key=access_key['SecretAccessKey'])
# if the IAM principal whose access key is used does not have permission to describe instances, this will fail
instances = client.describe_instances()['Reservations'][0]['Instances']
assert len(instances) == 0
```
See [the related test suite](https://github.com/spulec/moto/blob/master/tests/test_core/test_auth.py) for more examples.
## Experimental: AWS Config Querying
For details about the experimental AWS Config support please see the [AWS Config readme here](CONFIG_README.md).
## Very Important -- Recommended Usage
There are some important caveats to be aware of when using moto:
*Failure to follow these guidelines could result in your tests mutating your __REAL__ infrastructure!*
### How do I avoid my tests mutating my real infrastructure?
You need to ensure that the mocks are actually in place. Changes made to recent versions of `botocore`
have altered some of the mock behavior. In short, you need to ensure that you _always_ do the following:
1. Ensure that your tests have dummy environment variables set up:
export AWS_ACCESS_KEY_ID='testing'
export AWS_SECRET_ACCESS_KEY='testing'
export AWS_SECURITY_TOKEN='testing'
export AWS_SESSION_TOKEN='testing'
1. __VERY IMPORTANT__: ensure that you have your mocks set up __BEFORE__ your `boto3` client is established.
This can typically happen if you import a module that has a `boto3` client instantiated outside of a function.
See the pesky imports section below on how to work around this.
### Example usage
If you are a user of [pytest](https://pytest.org/en/latest/), you can leverage [pytest fixtures](https://pytest.org/en/latest/fixture.html#fixture)
to help set up your mocks and other AWS resources that you would need.
Here is an example:
```python
import os

import boto3
import pytest

from moto import mock_cloudwatch, mock_s3, mock_sts


@pytest.fixture(scope='function')
def aws_credentials():
"""Mocked AWS Credentials for moto."""
os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
os.environ['AWS_SECURITY_TOKEN'] = 'testing'
os.environ['AWS_SESSION_TOKEN'] = 'testing'
@pytest.fixture(scope='function')
def s3(aws_credentials):
with mock_s3():
yield boto3.client('s3', region_name='us-east-1')
@pytest.fixture(scope='function')
def sts(aws_credentials):
with mock_sts():
yield boto3.client('sts', region_name='us-east-1')
@pytest.fixture(scope='function')
def cloudwatch(aws_credentials):
with mock_cloudwatch():
yield boto3.client('cloudwatch', region_name='us-east-1')
# ... etc.
```
In the code sample above, all of the AWS/mocked fixtures take in a parameter of `aws_credentials`,
which sets the proper fake environment variables. The fake environment variables are used so that `botocore` doesn't try to locate real
credentials on your system.
Next, once you need to do anything with the mocked AWS environment, do something like:
```python
def test_create_bucket(s3):
# s3 is a fixture defined above that yields a boto3 s3 client.
# Feel free to instantiate another boto3 S3 client -- Keep note of the region though.
s3.create_bucket(Bucket="somebucket")
result = s3.list_buckets()
assert len(result['Buckets']) == 1
assert result['Buckets'][0]['Name'] == 'somebucket'
```
### What about those pesky imports?
Recall earlier, it was mentioned that mocks should be established __BEFORE__ the clients are set up. One way
to avoid import issues is to make use of local Python imports -- i.e. import the module inside of the unit
test you want to run vs. importing at the top of the file.
Example:
```python
def test_something(s3):
from some.package.that.does.something.with.s3 import some_func # <-- Local import for unit test
# ^^ Importing here ensures that the mock has been established.
some_func() # The mock has been established from the "s3" pytest fixture, so this function that uses
# a package-level S3 client will properly use the mock and not reach out to AWS.
```
### Other caveats
For Tox, Travis CI, and other build systems, you might need to also perform a `touch ~/.aws/credentials`
command before running the tests. As long as that file is present (empty preferably) and the environment
variables above are set, you should be good to go.
## Stand-alone Server Mode
Moto also has a stand-alone server mode. This allows you to utilize
@@ -315,6 +450,16 @@ boto3.resource(
)
```
### Caveats
The standalone server has some caveats with some services. The following services
require that you update your hosts file for your code to work properly:
1. `s3-control`
For the above services, this is required because the hostname is in the form of `AWS_ACCOUNT_ID.localhost`.
As a result, you need to add that entry to your hosts file for your tests to function properly.
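For example, with moto's default account ID of `123456789012`, the hosts file entry would look something like this (adjust the ID if your setup uses a different one):
```
127.0.0.1   123456789012.localhost
```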
## Install


@@ -30,6 +30,8 @@ Currently implemented Services:
+-----------------------+---------------------+-----------------------------------+
| Data Pipeline | @mock_datapipeline | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+
| DataSync | @mock_datasync | some endpoints done |
+-----------------------+---------------------+-----------------------------------+
| - DynamoDB | - @mock_dynamodb | - core endpoints done |
| - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes|
+-----------------------+---------------------+-----------------------------------+

docs/conf.py

@@ -56,9 +56,10 @@ author = 'Steve Pulec'
# built documents.
#
# The short X.Y version.
version = '0.4.10'
import moto
version = moto.__version__
# The full version, including alpha/beta/rc tags.
release = '0.4.10'
release = moto.__version__
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.

docs/docs/getting_started.rst

@@ -24,8 +24,7 @@ For example, we have the following code we want to test:
.. sourcecode:: python
import boto
from boto.s3.key import Key
import boto3
class MyModel(object):
def __init__(self, name, value):
@@ -33,11 +32,8 @@ For example, we have the following code we want to test:
self.value = value
def save(self):
conn = boto.connect_s3()
bucket = conn.get_bucket('mybucket')
k = Key(bucket)
k.key = self.name
k.set_contents_from_string(self.value)
s3 = boto3.client('s3', region_name='us-east-1')
s3.put_object(Bucket='mybucket', Key=self.name, Body=self.value)
There are several ways to do this, but you should keep in mind that Moto creates a full, blank environment.
@@ -48,20 +44,23 @@ With a decorator wrapping, all the calls to S3 are automatically mocked out.
.. sourcecode:: python
import boto
import boto3
from moto import mock_s3
from mymodule import MyModel
@mock_s3
def test_my_model_save():
conn = boto.connect_s3()
conn = boto3.resource('s3', region_name='us-east-1')
# We need to create the bucket since this is all in Moto's 'virtual' AWS account
conn.create_bucket('mybucket')
conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
Context manager
~~~~~~~~~~~~~~~
@@ -72,13 +71,16 @@ Same as the Decorator, every call inside the ``with`` statement is mocked out.
def test_my_model_save():
with mock_s3():
conn = boto.connect_s3()
conn.create_bucket('mybucket')
conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
Raw
~~~
@@ -91,13 +93,16 @@ You can also start and stop the mocking manually.
mock = mock_s3()
mock.start()
conn = boto.connect_s3()
conn.create_bucket('mybucket')
conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
mock.stop()


@@ -76,7 +76,7 @@ Currently implemented Services:
+---------------------------+-----------------------+------------------------------------+
| Logs | @mock_logs | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Organizations | @mock_organizations | some core edpoints done |
| Organizations | @mock_organizations | some core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Polly | @mock_polly | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
@@ -94,6 +94,8 @@ Currently implemented Services:
+---------------------------+-----------------------+------------------------------------+
| SES | @mock_ses | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SFN | @mock_stepfunctions | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SNS | @mock_sns | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SQS | @mock_sqs | core endpoints done |

moto/__init__.py

@@ -1,62 +1,77 @@
from __future__ import unicode_literals
import logging
from .acm import mock_acm # noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # noqa
from .athena import mock_athena # noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # noqa
from .batch import mock_batch # noqa
from .cloudformation import mock_cloudformation # noqa
from .cloudformation import mock_cloudformation_deprecated # noqa
from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # noqa
from .codecommit import mock_codecommit # noqa
from .codepipeline import mock_codepipeline # noqa
from .cognitoidentity import mock_cognitoidentity # noqa
from .cognitoidentity import mock_cognitoidentity_deprecated # noqa
from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # noqa
from .config import mock_config # noqa
from .datapipeline import mock_datapipeline # noqa
from .datapipeline import mock_datapipeline_deprecated # noqa
from .datasync import mock_datasync # noqa
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # noqa
from .dynamodbstreams import mock_dynamodbstreams # noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # noqa
from .ec2_instance_connect import mock_ec2_instance_connect # noqa
from .ecr import mock_ecr, mock_ecr_deprecated # noqa
from .ecs import mock_ecs, mock_ecs_deprecated # noqa
from .elb import mock_elb, mock_elb_deprecated # noqa
from .elbv2 import mock_elbv2 # noqa
from .emr import mock_emr, mock_emr_deprecated # noqa
from .events import mock_events # noqa
from .glacier import mock_glacier, mock_glacier_deprecated # noqa
from .glue import mock_glue # noqa
from .iam import mock_iam, mock_iam_deprecated # noqa
from .iot import mock_iot # noqa
from .iotdata import mock_iotdata # noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # noqa
from .kms import mock_kms, mock_kms_deprecated # noqa
from .logs import mock_logs, mock_logs_deprecated # noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # noqa
from .organizations import mock_organizations # noqa
from .polly import mock_polly # noqa
from .rds import mock_rds, mock_rds_deprecated # noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # noqa
from .redshift import mock_redshift, mock_redshift_deprecated # noqa
from .resourcegroups import mock_resourcegroups # noqa
from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # noqa
from .route53 import mock_route53, mock_route53_deprecated # noqa
from .s3 import mock_s3, mock_s3_deprecated # noqa
from .secretsmanager import mock_secretsmanager # noqa
from .ses import mock_ses, mock_ses_deprecated # noqa
from .sns import mock_sns, mock_sns_deprecated # noqa
from .sqs import mock_sqs, mock_sqs_deprecated # noqa
from .ssm import mock_ssm # noqa
from .stepfunctions import mock_stepfunctions # noqa
from .sts import mock_sts, mock_sts_deprecated # noqa
from .swf import mock_swf, mock_swf_deprecated # noqa
from .xray import XRaySegment, mock_xray, mock_xray_client # noqa
# import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto'
__version__ = '1.3.14.dev'
from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # flake8: noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # flake8: noqa
from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # flake8: noqa
from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # flake8: noqa
from .cognitoidentity import mock_cognitoidentity, mock_cognitoidentity_deprecated # flake8: noqa
from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # flake8: noqa
from .config import mock_config # flake8: noqa
from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # flake8: noqa
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # flake8: noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # flake8: noqa
from .dynamodbstreams import mock_dynamodbstreams # flake8: noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # flake8: noqa
from .ecr import mock_ecr, mock_ecr_deprecated # flake8: noqa
from .ecs import mock_ecs, mock_ecs_deprecated # flake8: noqa
from .elb import mock_elb, mock_elb_deprecated # flake8: noqa
from .elbv2 import mock_elbv2 # flake8: noqa
from .emr import mock_emr, mock_emr_deprecated # flake8: noqa
from .events import mock_events # flake8: noqa
from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa
from .glue import mock_glue # flake8: noqa
from .iam import mock_iam, mock_iam_deprecated # flake8: noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa
from .kms import mock_kms, mock_kms_deprecated # flake8: noqa
from .organizations import mock_organizations # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .polly import mock_polly # flake8: noqa
from .rds import mock_rds, mock_rds_deprecated # flake8: noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa
from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa
from .resourcegroups import mock_resourcegroups # flake8: noqa
from .s3 import mock_s3, mock_s3_deprecated # flake8: noqa
from .ses import mock_ses, mock_ses_deprecated # flake8: noqa
from .secretsmanager import mock_secretsmanager # flake8: noqa
from .sns import mock_sns, mock_sns_deprecated # flake8: noqa
from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa
from .sts import mock_sts, mock_sts_deprecated # flake8: noqa
from .ssm import mock_ssm # flake8: noqa
from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa
from .swf import mock_swf, mock_swf_deprecated # flake8: noqa
from .xray import mock_xray, mock_xray_client, XRaySegment # flake8: noqa
from .logs import mock_logs, mock_logs_deprecated # flake8: noqa
from .batch import mock_batch # flake8: noqa
from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # flake8: noqa
from .iot import mock_iot # flake8: noqa
from .iotdata import mock_iotdata # flake8: noqa
__title__ = "moto"
__version__ = "1.3.15.dev"
try:
# Need to monkey-patch botocore requests back to underlying urllib3 classes
from botocore.awsrequest import HTTPSConnectionPool, HTTPConnectionPool, HTTPConnection, VerifiedHTTPSConnection
from botocore.awsrequest import (
HTTPSConnectionPool,
HTTPConnectionPool,
HTTPConnection,
VerifiedHTTPSConnection,
)
except ImportError:
pass
else:

moto/acm/__init__.py

@@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import acm_backends
from ..core.models import base_decorator
acm_backend = acm_backends['us-east-1']
acm_backend = acm_backends["us-east-1"]
mock_acm = base_decorator(acm_backends)

moto/acm/models.py

@@ -13,8 +13,9 @@ import cryptography.hazmat.primitives.asymmetric.rsa
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.backends import default_backend
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
DEFAULT_ACCOUNT_ID = 123456789012
GOOGLE_ROOT_CA = b"""-----BEGIN CERTIFICATE-----
MIIEKDCCAxCgAwIBAgIQAQAhJYiw+lmnd+8Fe2Yn3zANBgkqhkiG9w0BAQsFADBC
MQswCQYDVQQGEwJVUzEWMBQGA1UEChMNR2VvVHJ1c3QgSW5jLjEbMBkGA1UEAxMS
@@ -57,20 +58,29 @@ class AWSError(Exception):
self.message = message
def response(self):
resp = {'__type': self.TYPE, 'message': self.message}
resp = {"__type": self.TYPE, "message": self.message}
return json.dumps(resp), dict(status=self.STATUS)
class AWSValidationException(AWSError):
TYPE = 'ValidationException'
TYPE = "ValidationException"
class AWSResourceNotFoundException(AWSError):
TYPE = 'ResourceNotFoundException'
TYPE = "ResourceNotFoundException"
class CertBundle(BaseModel):
def __init__(self, certificate, private_key, chain=None, region='us-east-1', arn=None, cert_type='IMPORTED', cert_status='ISSUED'):
def __init__(
self,
certificate,
private_key,
chain=None,
region="us-east-1",
arn=None,
cert_type="IMPORTED",
cert_status="ISSUED",
):
self.created_at = datetime.datetime.now()
self.cert = certificate
self._cert = None
@@ -87,7 +97,7 @@ class CertBundle(BaseModel):
if self.chain is None:
self.chain = GOOGLE_ROOT_CA
else:
self.chain += b'\n' + GOOGLE_ROOT_CA
self.chain += b"\n" + GOOGLE_ROOT_CA
# Takes care of PEM checking
self.validate_pk()
@@ -105,7 +115,7 @@ class CertBundle(BaseModel):
self.arn = arn
@classmethod
def generate_cert(cls, domain_name, sans=None):
def generate_cert(cls, domain_name, region, sans=None):
if sans is None:
sans = set()
else:
@@ -114,149 +124,209 @@ class CertBundle(BaseModel):
sans.add(domain_name)
sans = [cryptography.x509.DNSName(item) for item in sans]
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())
subject = cryptography.x509.Name([
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, u"CA"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.LOCALITY_NAME, u"San Francisco"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"My Company"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, domain_name),
])
issuer = cryptography.x509.Name([ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"Amazon"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, u"Server CA 1B"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, u"Amazon"),
])
cert = cryptography.x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
key.public_key()
).serial_number(
cryptography.x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=365)
).add_extension(
cryptography.x509.SubjectAlternativeName(sans),
critical=False,
).sign(key, hashes.SHA512(), default_backend())
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)
subject = cryptography.x509.Name(
[
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COUNTRY_NAME, "US"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, domain_name
),
]
)
issuer = cryptography.x509.Name(
[ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COUNTRY_NAME, "US"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, "Amazon"
),
]
)
cert = (
cryptography.x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(cryptography.x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
.add_extension(
cryptography.x509.SubjectAlternativeName(sans), critical=False
)
.sign(key, hashes.SHA512(), default_backend())
)
cert_armored = cert.public_bytes(serialization.Encoding.PEM)
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
encryption_algorithm=serialization.NoEncryption(),
)
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION')
return cls(
cert_armored,
private_key,
cert_type="AMAZON_ISSUED",
cert_status="PENDING_VALIDATION",
region=region,
)
def validate_pk(self):
try:
self._key = serialization.load_pem_private_key(self.key, password=None, backend=default_backend())
self._key = serialization.load_pem_private_key(
self.key, password=None, backend=default_backend()
)
if self._key.key_size > 2048:
AWSValidationException('The private key length is not supported. Only 1024-bit and 2048-bit are allowed.')
AWSValidationException(
"The private key length is not supported. Only 1024-bit and 2048-bit are allowed."
)
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The private key is not PEM-encoded or is not valid.')
raise AWSValidationException(
"The private key is not PEM-encoded or is not valid."
)
def validate_certificate(self):
try:
self._cert = cryptography.x509.load_pem_x509_certificate(self.cert, default_backend())
self._cert = cryptography.x509.load_pem_x509_certificate(
self.cert, default_backend()
)
now = datetime.datetime.utcnow()
if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate has expired, is not valid.')
raise AWSValidationException(
"The certificate has expired, is not valid."
)
if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate is not in effect yet, is not valid.')
raise AWSValidationException(
"The certificate is not in effect yet, is not valid."
)
# Extracting some common fields for ease of use
# Have to search through cert.subject for OIDs
self.common_name = self._cert.subject.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value
self.common_name = self._cert.subject.get_attributes_for_oid(
cryptography.x509.OID_COMMON_NAME
)[0].value
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.')
raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
)
def validate_chain(self):
try:
self._chain = []
for cert_armored in self.chain.split(b'-\n-'):
for cert_armored in self.chain.split(b"-\n-"):
# Would leave encoded but Py2 does not have raw binary strings
cert_armored = cert_armored.decode()
# Fix missing -'s on split
cert_armored = re.sub(r'^----B', '-----B', cert_armored)
cert_armored = re.sub(r'E----$', 'E-----', cert_armored)
cert = cryptography.x509.load_pem_x509_certificate(cert_armored.encode(), default_backend())
cert_armored = re.sub(r"^----B", "-----B", cert_armored)
cert_armored = re.sub(r"E----$", "E-----", cert_armored)
cert = cryptography.x509.load_pem_x509_certificate(
cert_armored.encode(), default_backend()
)
self._chain.append(cert)
now = datetime.datetime.now()
if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate chain has expired, is not valid.')
raise AWSValidationException(
"The certificate chain has expired, is not valid."
)
if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate chain is not in effect yet, is not valid.')
raise AWSValidationException(
"The certificate chain is not in effect yet, is not valid."
)
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.')
raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
)
def check(self):
# Basically, if the certificate is pending, and then checked again after 1 min
# It will appear as if its been validated
if self.type == 'AMAZON_ISSUED' and self.status == 'PENDING_VALIDATION' and \
(datetime.datetime.now() - self.created_at).total_seconds() > 60: # 1min
self.status = 'ISSUED'
if (
self.type == "AMAZON_ISSUED"
and self.status == "PENDING_VALIDATION"
and (datetime.datetime.now() - self.created_at).total_seconds() > 60
): # 1min
self.status = "ISSUED"
def describe(self):
# 'RenewalSummary': {}, # Only when cert is amazon issued
if self._key.key_size == 1024:
key_algo = 'RSA_1024'
key_algo = "RSA_1024"
elif self._key.key_size == 2048:
key_algo = 'RSA_2048'
key_algo = "RSA_2048"
else:
key_algo = 'EC_prime256v1'
key_algo = "EC_prime256v1"
# Look for SANs
san_obj = self._cert.extensions.get_extension_for_oid(cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME)
san_obj = self._cert.extensions.get_extension_for_oid(
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
)
sans = []
if san_obj is not None:
sans = [item.value for item in san_obj.value]
result = {
'Certificate': {
'CertificateArn': self.arn,
'DomainName': self.common_name,
'InUseBy': [],
'Issuer': self._cert.issuer.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value,
'KeyAlgorithm': key_algo,
'NotAfter': datetime_to_epoch(self._cert.not_valid_after),
'NotBefore': datetime_to_epoch(self._cert.not_valid_before),
'Serial': self._cert.serial_number,
'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''),
'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED.
'Subject': 'CN={0}'.format(self.common_name),
'SubjectAlternativeNames': sans,
'Type': self.type # One of IMPORTED, AMAZON_ISSUED
"Certificate": {
"CertificateArn": self.arn,
"DomainName": self.common_name,
"InUseBy": [],
"Issuer": self._cert.issuer.get_attributes_for_oid(
cryptography.x509.OID_COMMON_NAME
)[0].value,
"KeyAlgorithm": key_algo,
"NotAfter": datetime_to_epoch(self._cert.not_valid_after),
"NotBefore": datetime_to_epoch(self._cert.not_valid_before),
"Serial": self._cert.serial_number,
"SignatureAlgorithm": self._cert.signature_algorithm_oid._name.upper().replace(
"ENCRYPTION", ""
),
"Status": self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED.
"Subject": "CN={0}".format(self.common_name),
"SubjectAlternativeNames": sans,
"Type": self.type, # One of IMPORTED, AMAZON_ISSUED
}
}
if self.type == 'IMPORTED':
result['Certificate']['ImportedAt'] = datetime_to_epoch(self.created_at)
if self.type == "IMPORTED":
result["Certificate"]["ImportedAt"] = datetime_to_epoch(self.created_at)
else:
result['Certificate']['CreatedAt'] = datetime_to_epoch(self.created_at)
result['Certificate']['IssuedAt'] = datetime_to_epoch(self.created_at)
result["Certificate"]["CreatedAt"] = datetime_to_epoch(self.created_at)
result["Certificate"]["IssuedAt"] = datetime_to_epoch(self.created_at)
return result
@@ -264,7 +334,7 @@ class CertBundle(BaseModel):
return self.arn
def __repr__(self):
return '<Certificate>'
return "<Certificate>"
class AWSCertificateManagerBackend(BaseBackend):
@@ -281,7 +351,9 @@ class AWSCertificateManagerBackend(BaseBackend):
@staticmethod
def _arn_not_found(arn):
msg = 'Certificate with arn {0} not found in account {1}'.format(arn, DEFAULT_ACCOUNT_ID)
msg = "Certificate with arn {0} not found in account {1}".format(
arn, DEFAULT_ACCOUNT_ID
)
return AWSResourceNotFoundException(msg)
def _get_arn_from_idempotency_token(self, token):
@@ -298,17 +370,20 @@ class AWSCertificateManagerBackend(BaseBackend):
"""
now = datetime.datetime.now()
if token in self._idempotency_tokens:
if self._idempotency_tokens[token]['expires'] < now:
if self._idempotency_tokens[token]["expires"] < now:
# Token has expired, new request
del self._idempotency_tokens[token]
return None
else:
return self._idempotency_tokens[token]['arn']
return self._idempotency_tokens[token]["arn"]
return None
def _set_idempotency_token_arn(self, token, arn):
self._idempotency_tokens[token] = {'arn': arn, 'expires': datetime.datetime.now() + datetime.timedelta(hours=1)}
self._idempotency_tokens[token] = {
"arn": arn,
"expires": datetime.datetime.now() + datetime.timedelta(hours=1),
}
def import_cert(self, certificate, private_key, chain=None, arn=None):
if arn is not None:
@@ -316,7 +391,9 @@ class AWSCertificateManagerBackend(BaseBackend):
raise self._arn_not_found(arn)
else:
# Will reuse provided ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region, arn=arn)
bundle = CertBundle(
certificate, private_key, chain=chain, region=region, arn=arn
)
else:
# Will generate a random ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region)
@@ -325,7 +402,7 @@ class AWSCertificateManagerBackend(BaseBackend):
return bundle.arn
def get_certificates_list(self):
def get_certificates_list(self, statuses):
"""
Get list of certificates
@@ -333,7 +410,9 @@ class AWSCertificateManagerBackend(BaseBackend):
:rtype: list of CertBundle
"""
for arn in self._certificates.keys():
yield self.get_certificate(arn)
cert = self.get_certificate(arn)
if not statuses or cert.status in statuses:
yield cert
def get_certificate(self, arn):
if arn not in self._certificates:
@@ -349,13 +428,21 @@ class AWSCertificateManagerBackend(BaseBackend):
del self._certificates[arn]
def request_certificate(self, domain_name, domain_validation_options, idempotency_token, subject_alt_names):
def request_certificate(
self,
domain_name,
domain_validation_options,
idempotency_token,
subject_alt_names,
):
if idempotency_token is not None:
arn = self._get_arn_from_idempotency_token(idempotency_token)
if arn is not None:
return arn
cert = CertBundle.generate_cert(domain_name, subject_alt_names)
cert = CertBundle.generate_cert(
domain_name, region=self.region, sans=subject_alt_names
)
if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert
@@ -367,8 +454,8 @@ class AWSCertificateManagerBackend(BaseBackend):
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag['Key']
value = tag.get('Value', None)
key = tag["Key"]
value = tag.get("Value", None)
cert_bundle.tags[key] = value
def remove_tags_from_certificate(self, arn, tags):
@@ -376,8 +463,8 @@ class AWSCertificateManagerBackend(BaseBackend):
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag['Key']
value = tag.get('Value', None)
key = tag["Key"]
value = tag.get("Value", None)
try:
# If value isnt provided, just delete key

moto/acm/responses.py

@@ -7,7 +7,6 @@ from .models import acm_backends, AWSError, AWSValidationException
class AWSCertificateManagerResponse(BaseResponse):
@property
def acm_backend(self):
"""
@@ -29,40 +28,49 @@ class AWSCertificateManagerResponse(BaseResponse):
return self.request_params.get(param, default)
def add_tags_to_certificate(self):
arn = self._get_param('CertificateArn')
tags = self._get_param('Tags')
arn = self._get_param("CertificateArn")
tags = self._get_param("Tags")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
self.acm_backend.add_tags_to_certificate(arn, tags)
except AWSError as err:
return err.response()
return ''
return ""
def delete_certificate(self):
arn = self._get_param('CertificateArn')
arn = self._get_param("CertificateArn")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
self.acm_backend.delete_certificate(arn)
except AWSError as err:
return err.response()
return ''
return ""
def describe_certificate(self):
arn = self._get_param('CertificateArn')
arn = self._get_param("CertificateArn")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
@@ -72,11 +80,14 @@ class AWSCertificateManagerResponse(BaseResponse):
return json.dumps(cert_bundle.describe())
def get_certificate(self):
arn = self._get_param('CertificateArn')
arn = self._get_param("CertificateArn")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
@@ -84,8 +95,8 @@ class AWSCertificateManagerResponse(BaseResponse):
return err.response()
result = {
'Certificate': cert_bundle.cert.decode(),
'CertificateChain': cert_bundle.chain.decode()
"Certificate": cert_bundle.cert.decode(),
"CertificateChain": cert_bundle.chain.decode(),
}
return json.dumps(result)
@@ -102,104 +113,129 @@ class AWSCertificateManagerResponse(BaseResponse):
:return: str(JSON) for response
"""
certificate = self._get_param('Certificate')
private_key = self._get_param('PrivateKey')
chain = self._get_param('CertificateChain') # Optional
current_arn = self._get_param('CertificateArn') # Optional
certificate = self._get_param("Certificate")
private_key = self._get_param("PrivateKey")
chain = self._get_param("CertificateChain") # Optional
current_arn = self._get_param("CertificateArn") # Optional
# Simple parameter decoding. Rather do it here as its a data transport decision not part of the
# actual data
try:
certificate = base64.standard_b64decode(certificate)
except Exception:
return AWSValidationException('The certificate is not PEM-encoded or is not valid.').response()
return AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
).response()
try:
private_key = base64.standard_b64decode(private_key)
except Exception:
return AWSValidationException('The private key is not PEM-encoded or is not valid.').response()
return AWSValidationException(
"The private key is not PEM-encoded or is not valid."
).response()
if chain is not None:
try:
chain = base64.standard_b64decode(chain)
except Exception:
return AWSValidationException('The certificate chain is not PEM-encoded or is not valid.').response()
return AWSValidationException(
"The certificate chain is not PEM-encoded or is not valid."
).response()
try:
arn = self.acm_backend.import_cert(certificate, private_key, chain=chain, arn=current_arn)
arn = self.acm_backend.import_cert(
certificate, private_key, chain=chain, arn=current_arn
)
except AWSError as err:
return err.response()
return json.dumps({'CertificateArn': arn})
return json.dumps({"CertificateArn": arn})
def list_certificates(self):
certs = []
statuses = self._get_param("CertificateStatuses")
for cert_bundle in self.acm_backend.get_certificates_list(statuses):
certs.append(
{
"CertificateArn": cert_bundle.arn,
"DomainName": cert_bundle.common_name,
}
)
for cert_bundle in self.acm_backend.get_certificates_list():
certs.append({
'CertificateArn': cert_bundle.arn,
'DomainName': cert_bundle.common_name
})
result = {'CertificateSummaryList': certs}
result = {"CertificateSummaryList": certs}
return json.dumps(result)
def list_tags_for_certificate(self):
arn = self._get_param('CertificateArn')
arn = self._get_param("CertificateArn")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return {'__type': 'MissingParameter', 'message': msg}, dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return {"__type": "MissingParameter", "message": msg}, dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
result = {'Tags': []}
result = {"Tags": []}
# Tag "objects" can not contain the Value part
for key, value in cert_bundle.tags.items():
tag_dict = {'Key': key}
tag_dict = {"Key": key}
if value is not None:
tag_dict['Value'] = value
result['Tags'].append(tag_dict)
tag_dict["Value"] = value
result["Tags"].append(tag_dict)
return json.dumps(result)
def remove_tags_from_certificate(self):
arn = self._get_param('CertificateArn')
tags = self._get_param('Tags')
arn = self._get_param("CertificateArn")
tags = self._get_param("Tags")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
self.acm_backend.remove_tags_from_certificate(arn, tags)
except AWSError as err:
return err.response()
return ''
return ""
def request_certificate(self):
domain_name = self._get_param('DomainName')
domain_validation_options = self._get_param('DomainValidationOptions') # is ignored atm
idempotency_token = self._get_param('IdempotencyToken')
subject_alt_names = self._get_param('SubjectAlternativeNames')
domain_name = self._get_param("DomainName")
domain_validation_options = self._get_param(
"DomainValidationOptions"
) # is ignored atm
idempotency_token = self._get_param("IdempotencyToken")
subject_alt_names = self._get_param("SubjectAlternativeNames")
if subject_alt_names is not None and len(subject_alt_names) > 10:
# There is initial AWS limit of 10
msg = 'An ACM limit has been exceeded. Need to request SAN limit to be raised'
return json.dumps({'__type': 'LimitExceededException', 'message': msg}), dict(status=400)
msg = (
"An ACM limit has been exceeded. Need to request SAN limit to be raised"
)
return (
json.dumps({"__type": "LimitExceededException", "message": msg}),
dict(status=400),
)
try:
arn = self.acm_backend.request_certificate(domain_name, domain_validation_options, idempotency_token, subject_alt_names)
arn = self.acm_backend.request_certificate(
domain_name,
domain_validation_options,
idempotency_token,
subject_alt_names,
)
except AWSError as err:
return err.response()
return json.dumps({'CertificateArn': arn})
return json.dumps({"CertificateArn": arn})
def resend_validation_email(self):
arn = self._get_param('CertificateArn')
domain = self._get_param('Domain')
arn = self._get_param("CertificateArn")
domain = self._get_param("Domain")
# ValidationDomain not used yet.
# Contains domain which is equal to or a subset of Domain
# that AWS will send validation emails to
@@ -207,18 +243,21 @@ class AWSCertificateManagerResponse(BaseResponse):
# validation_domain = self._get_param('ValidationDomain')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
if cert_bundle.common_name != domain:
msg = "Parameter Domain does not match certificate domain"
_type = "InvalidDomainValidationOptionsException"
return json.dumps({"__type": _type, "message": msg}), dict(status=400)
except AWSError as err:
return err.response()
return ""

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import AWSCertificateManagerResponse
url_bases = ["https?://acm.(.+).amazonaws.com"]
url_paths = {"{0}/$": AWSCertificateManagerResponse.dispatch}

View File

@ -4,4 +4,6 @@ import uuid
def make_arn_for_certificate(account_id, region_name):
# Example
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(region_name, account_id, uuid.uuid4())
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(
region_name, account_id, uuid.uuid4()
)
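As a quick illustration (the account id below is made up; the UUID suffix is random on every call):
# make_arn_for_certificate("123456789012", "eu-west-2") returns something like:
# "arn:aws:acm:eu-west-2:123456789012:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b"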

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import apigateway_backends
from ..core.models import base_decorator, deprecated_base_decorator
apigateway_backend = apigateway_backends["us-east-1"]
mock_apigateway = base_decorator(apigateway_backends)
mock_apigateway_deprecated = deprecated_base_decorator(apigateway_backends)

View File

@ -2,12 +2,105 @@ from __future__ import unicode_literals
from moto.core.exceptions import RESTError
class BadRequestException(RESTError):
pass
class AwsProxyNotAllowed(BadRequestException):
def __init__(self):
super(AwsProxyNotAllowed, self).__init__(
"BadRequestException",
"Integrations of type 'AWS_PROXY' currently only supports Lambda function and Firehose stream invocations.",
)
class CrossAccountNotAllowed(RESTError):
def __init__(self):
super(CrossAccountNotAllowed, self).__init__(
"AccessDeniedException", "Cross-account pass role is not allowed."
)
class RoleNotSpecified(BadRequestException):
def __init__(self):
super(RoleNotSpecified, self).__init__(
"BadRequestException", "Role ARN must be specified for AWS integrations"
)
class IntegrationMethodNotDefined(BadRequestException):
def __init__(self):
super(IntegrationMethodNotDefined, self).__init__(
"BadRequestException", "Enumeration value for HttpMethod must be non-empty"
)
class InvalidResourcePathException(BadRequestException):
def __init__(self):
super(InvalidResourcePathException, self).__init__(
"BadRequestException",
"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace.",
)
class InvalidHttpEndpoint(BadRequestException):
def __init__(self):
super(InvalidHttpEndpoint, self).__init__(
"BadRequestException", "Invalid HTTP endpoint specified for URI"
)
class InvalidArn(BadRequestException):
def __init__(self):
super(InvalidArn, self).__init__(
"BadRequestException", "Invalid ARN specified in the request"
)
class InvalidIntegrationArn(BadRequestException):
def __init__(self):
super(InvalidIntegrationArn, self).__init__(
"BadRequestException", "AWS ARN for integration must contain path or action"
)
class InvalidRequestInput(BadRequestException):
def __init__(self):
super(InvalidRequestInput, self).__init__(
"BadRequestException", "Invalid request input"
)
class NoIntegrationDefined(BadRequestException):
def __init__(self):
super(NoIntegrationDefined, self).__init__(
"BadRequestException", "No integration defined for method"
)
class NoMethodDefined(BadRequestException):
def __init__(self):
super(NoMethodDefined, self).__init__(
"BadRequestException", "The REST API doesn't contain any methods"
)
class AuthorizerNotFoundException(RESTError):
code = 404
def __init__(self):
super(AuthorizerNotFoundException, self).__init__(
"NotFoundException", "Invalid Authorizer identifier specified"
)
class StageNotFoundException(RESTError):
code = 404
def __init__(self):
super(StageNotFoundException, self).__init__(
"NotFoundException", "Invalid stage identifier specified")
"NotFoundException", "Invalid stage identifier specified"
)
class ApiKeyNotFoundException(RESTError):
@ -15,4 +108,14 @@ class ApiKeyNotFoundException(RESTError):
def __init__(self):
super(ApiKeyNotFoundException, self).__init__(
"NotFoundException", "Invalid API Key identifier specified")
"NotFoundException", "Invalid API Key identifier specified"
)
class ApiKeyAlreadyExists(RESTError):
code = 409
def __init__(self):
super(ApiKeyAlreadyExists, self).__init__(
"ConflictException", "API Key already exists"
)
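The handlers in responses.py consume these exceptions through a common pattern; a minimal sketch (the attribute names `code`, `error_type`, and `message` are the ones those handlers already use):
# Sketch only: the exception's `code` becomes the HTTP status, while
# `error_type` and `message` become the JSON error body.
try:
    raise ApiKeyAlreadyExists()
except ApiKeyAlreadyExists as error:
    status = error.code  # 409
    body = '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type)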

File diff suppressed because it is too large.

View File

@ -4,13 +4,30 @@ import json
from moto.core.responses import BaseResponse
from .models import apigateway_backends
from .exceptions import (
ApiKeyNotFoundException,
BadRequestException,
CrossAccountNotAllowed,
AuthorizerNotFoundException,
StageNotFoundException,
ApiKeyAlreadyExists,
)
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
AUTHORIZER_TYPES = ["TOKEN", "REQUEST", "COGNITO_USER_POOLS"]
ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400):
return (
status,
self.response_headers,
json.dumps({"__type": type_, "message": message}),
)
def _get_param(self, key):
return json.loads(self.body).get(key) if self.body else None
def _get_param_with_default_value(self, key, default):
jsonbody = json.loads(self.body)
@ -27,25 +44,61 @@ class APIGatewayResponse(BaseResponse):
def restapis(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if self.method == "GET":
apis = self.backend.list_apis()
return 200, {}, json.dumps({"item": [
api.to_dict() for api in apis
]})
elif self.method == 'POST':
name = self._get_param('name')
description = self._get_param('description')
rest_api = self.backend.create_rest_api(name, description)
return 200, {}, json.dumps({"item": [api.to_dict() for api in apis]})
elif self.method == "POST":
name = self._get_param("name")
description = self._get_param("description")
api_key_source = self._get_param("apiKeySource")
endpoint_configuration = self._get_param("endpointConfiguration")
tags = self._get_param("tags")
# Param validation
if api_key_source and api_key_source not in API_KEY_SOURCES:
return self.error(
"ValidationException",
(
"1 validation error detected: "
"Value '{api_key_source}' at 'createRestApiInput.apiKeySource' failed "
"to satisfy constraint: Member must satisfy enum value set: "
"[AUTHORIZER, HEADER]"
).format(api_key_source=api_key_source),
)
if endpoint_configuration and "types" in endpoint_configuration:
invalid_types = list(
set(endpoint_configuration["types"])
- set(ENDPOINT_CONFIGURATION_TYPES)
)
if invalid_types:
return self.error(
"ValidationException",
(
"1 validation error detected: Value '{endpoint_type}' "
"at 'createRestApiInput.endpointConfiguration.types' failed "
"to satisfy constraint: Member must satisfy enum value set: "
"[PRIVATE, EDGE, REGIONAL]"
).format(endpoint_type=invalid_types[0]),
)
rest_api = self.backend.create_rest_api(
name,
description,
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
)
return 200, {}, json.dumps(rest_api.to_dict())
def restapis_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == "GET":
rest_api = self.backend.get_rest_api(function_id)
return 200, {}, json.dumps(rest_api.to_dict())
elif self.method == "DELETE":
rest_api = self.backend.delete_rest_api(function_id)
return 200, {}, json.dumps(rest_api.to_dict())
@ -53,26 +106,34 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == "GET":
resources = self.backend.list_resources(function_id)
return 200, {}, json.dumps({"item": [
resource.to_dict() for resource in resources
]})
return (
200,
{},
json.dumps({"item": [resource.to_dict() for resource in resources]}),
)
def resource_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
resource_id = self.path.split("/")[-1]
try:
if self.method == "GET":
resource = self.backend.get_resource(function_id, resource_id)
elif self.method == "POST":
path_part = self._get_param("pathPart")
resource = self.backend.create_resource(
function_id, resource_id, path_part
)
elif self.method == "DELETE":
resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict())
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
def resource_methods(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -81,14 +142,19 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4]
method_type = url_path_parts[6]
if self.method == "GET":
method = self.backend.get_method(function_id, resource_id, method_type)
return 200, {}, json.dumps(method)
elif self.method == "PUT":
authorization_type = self._get_param("authorizationType")
api_key_required = self._get_param("apiKeyRequired")
method = self.backend.create_method(
function_id,
resource_id,
method_type,
authorization_type,
api_key_required,
)
return 200, {}, json.dumps(method)
def resource_method_responses(self, request, full_url, headers):
@ -99,37 +165,129 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6]
response_code = url_path_parts[8]
if self.method == "GET":
method_response = self.backend.get_method_response(
function_id, resource_id, method_type, response_code
)
elif self.method == "PUT":
method_response = self.backend.create_method_response(
function_id, resource_id, method_type, response_code
)
elif self.method == "DELETE":
method_response = self.backend.delete_method_response(
function_id, resource_id, method_type, response_code
)
return 200, {}, json.dumps(method_response)
def restapis_authorizers(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
restapi_id = url_path_parts[2]
if self.method == "POST":
name = self._get_param("name")
authorizer_type = self._get_param("type")
provider_arns = self._get_param_with_default_value("providerARNs", None)
auth_type = self._get_param_with_default_value("authType", None)
authorizer_uri = self._get_param_with_default_value("authorizerUri", None)
authorizer_credentials = self._get_param_with_default_value(
"authorizerCredentials", None
)
identity_source = self._get_param_with_default_value("identitySource", None)
identiy_validation_expression = self._get_param_with_default_value(
"identityValidationExpression", None
)
authorizer_result_ttl = self._get_param_with_default_value(
"authorizerResultTtlInSeconds", 300
)
# Param validation
if authorizer_type and authorizer_type not in AUTHORIZER_TYPES:
return self.error(
"ValidationException",
(
"1 validation error detected: "
"Value '{authorizer_type}' at 'createAuthorizerInput.type' failed "
"to satisfy constraint: Member must satisfy enum value set: "
"[TOKEN, REQUEST, COGNITO_USER_POOLS]"
).format(authorizer_type=authorizer_type),
)
authorizer_response = self.backend.create_authorizer(
restapi_id,
name,
authorizer_type,
provider_arns=provider_arns,
auth_type=auth_type,
authorizer_uri=authorizer_uri,
authorizer_credentials=authorizer_credentials,
identity_source=identity_source,
identiy_validation_expression=identiy_validation_expression,
authorizer_result_ttl=authorizer_result_ttl,
)
elif self.method == "GET":
authorizers = self.backend.get_authorizers(restapi_id)
return 200, {}, json.dumps({"item": authorizers})
return 200, {}, json.dumps(authorizer_response)
def authorizers(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
restapi_id = url_path_parts[2]
authorizer_id = url_path_parts[4]
if self.method == "GET":
try:
authorizer_response = self.backend.get_authorizer(
restapi_id, authorizer_id
)
except AuthorizerNotFoundException as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
authorizer_response = self.backend.update_authorizer(
restapi_id, authorizer_id, patch_operations
)
elif self.method == "DELETE":
self.backend.delete_authorizer(restapi_id, authorizer_id)
return 202, {}, "{}"
return 200, {}, json.dumps(authorizer_response)
def restapis_stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
function_id = url_path_parts[2]
if self.method == "POST":
stage_name = self._get_param("stageName")
deployment_id = self._get_param("deploymentId")
stage_variables = self._get_param_with_default_value("variables", {})
description = self._get_param_with_default_value("description", "")
cacheClusterEnabled = self._get_param_with_default_value(
"cacheClusterEnabled", False
)
cacheClusterSize = self._get_param_with_default_value(
"cacheClusterSize", None
)
stage_response = self.backend.create_stage(
function_id,
stage_name,
deployment_id,
variables=stage_variables,
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
)
elif self.method == "GET":
stages = self.backend.get_stages(function_id)
return 200, {}, json.dumps({"item": stages})
@ -141,16 +299,25 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2]
stage_name = url_path_parts[4]
if self.method == "GET":
try:
stage_response = self.backend.get_stage(function_id, stage_name)
except StageNotFoundException as error:
return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type)
elif self.method == 'PATCH':
patch_operations = self._get_param('patchOperations')
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
stage_response = self.backend.update_stage(
function_id, stage_name, patch_operations
)
elif self.method == "DELETE":
self.backend.delete_stage(function_id, stage_name)
return 202, {}, "{}"
return 200, {}, json.dumps(stage_response)
def integrations(self, request, full_url, headers):
@ -160,19 +327,40 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4]
method_type = url_path_parts[6]
try:
if self.method == "GET":
integration_response = self.backend.get_integration(
function_id, resource_id, method_type
)
elif self.method == "PUT":
integration_type = self._get_param("type")
uri = self._get_param("uri")
integration_http_method = self._get_param("httpMethod")
creds = self._get_param("credentials")
request_templates = self._get_param("requestTemplates")
integration_response = self.backend.create_integration(
function_id,
resource_id,
method_type,
integration_type,
uri,
credentials=creds,
integration_method=integration_http_method,
request_templates=request_templates,
)
elif self.method == "DELETE":
integration_response = self.backend.delete_integration(
function_id, resource_id, method_type
)
return 200, {}, json.dumps(integration_response)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
except CrossAccountNotAllowed as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#AccessDeniedException", e.message
)
def integration_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -182,36 +370,52 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6]
status_code = url_path_parts[9]
try:
if self.method == "GET":
integration_response = self.backend.get_integration_response(
function_id, resource_id, method_type, status_code
)
elif self.method == "PUT":
selection_pattern = self._get_param("selectionPattern")
response_templates = self._get_param("responseTemplates")
integration_response = self.backend.create_integration_response(
function_id,
resource_id,
method_type,
status_code,
selection_pattern,
response_templates,
)
elif self.method == "DELETE":
integration_response = self.backend.delete_integration_response(
function_id, resource_id, method_type, status_code
)
return 200, {}, json.dumps(integration_response)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
def deployments(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
try:
if self.method == "GET":
deployments = self.backend.get_deployments(function_id)
return 200, {}, json.dumps({"item": deployments})
elif self.method == "POST":
name = self._get_param("stageName")
description = self._get_param_with_default_value("description", "")
stage_variables = self._get_param_with_default_value("variables", {})
deployment = self.backend.create_deployment(
function_id, name, description, stage_variables
)
return 200, {}, json.dumps(deployment)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
def individual_deployment(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -219,20 +423,28 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2]
deployment_id = url_path_parts[4]
if self.method == "GET":
deployment = self.backend.get_deployment(function_id, deployment_id)
elif self.method == "DELETE":
deployment = self.backend.delete_deployment(function_id, deployment_id)
return 200, {}, json.dumps(deployment)
def apikeys(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if self.method == "POST":
try:
apikey_response = self.backend.create_apikey(json.loads(self.body))
except ApiKeyAlreadyExists as error:
return (
error.code,
self.headers,
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "GET":
apikeys_response = self.backend.get_apikeys()
return 200, {}, json.dumps({"item": apikeys_response})
return 200, {}, json.dumps(apikey_response)
@ -243,18 +455,21 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/")
apikey = url_path_parts[2]
if self.method == "GET":
apikey_response = self.backend.get_apikey(apikey)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
apikey_response = self.backend.update_apikey(apikey, patch_operations)
elif self.method == "DELETE":
apikey_response = self.backend.delete_apikey(apikey)
return 200, {}, json.dumps(apikey_response)
def usage_plans(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if self.method == "POST":
usage_plan_response = self.backend.create_usage_plan(json.loads(self.body))
elif self.method == "GET":
api_key_id = self.querystring.get("keyId", [None])[0]
usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id)
return 200, {}, json.dumps({"item": usage_plans_response})
@ -266,9 +481,9 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/")
usage_plan = url_path_parts[2]
if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan(usage_plan)
elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan(usage_plan)
return 200, {}, json.dumps(usage_plan_response)
@ -278,13 +493,21 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/")
usage_plan_id = url_path_parts[2]
if self.method == "POST":
try:
usage_plan_response = self.backend.create_usage_plan_key(
usage_plan_id, json.loads(self.body)
)
except ApiKeyNotFoundException as error:
return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type)
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "GET":
usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id)
return 200, {}, json.dumps({"item": usage_plans_response})
@ -297,8 +520,10 @@ class APIGatewayResponse(BaseResponse):
usage_plan_id = url_path_parts[2]
key_id = url_path_parts[4]
if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id)
elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan_key(
usage_plan_id, key_id
)
return 200, {}, json.dumps(usage_plan_response)

View File

@ -1,27 +1,27 @@
from __future__ import unicode_literals
from .responses import APIGatewayResponse
url_bases = ["https?://apigateway.(.+).amazonaws.com"]
url_paths = {
"{0}/restapis$": APIGatewayResponse().restapis,
"{0}/restapis/(?P<function_id>[^/]+)/?$": APIGatewayResponse().restapis_individual,
"{0}/restapis/(?P<function_id>[^/]+)/resources$": APIGatewayResponse().resources,
"{0}/restapis/(?P<function_id>[^/]+)/authorizers$": APIGatewayResponse().restapis_authorizers,
"{0}/restapis/(?P<function_id>[^/]+)/authorizers/(?P<authorizer_id>[^/]+)/?$": APIGatewayResponse().authorizers,
"{0}/restapis/(?P<function_id>[^/]+)/stages$": APIGatewayResponse().restapis_stages,
"{0}/restapis/(?P<function_id>[^/]+)/stages/(?P<stage_name>[^/]+)/?$": APIGatewayResponse().stages,
"{0}/restapis/(?P<function_id>[^/]+)/deployments$": APIGatewayResponse().deployments,
"{0}/restapis/(?P<function_id>[^/]+)/deployments/(?P<deployment_id>[^/]+)/?$": APIGatewayResponse().individual_deployment,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/?$": APIGatewayResponse().resource_individual,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/?$": APIGatewayResponse().resource_methods,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$": APIGatewayResponse().resource_method_responses,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/?$": APIGatewayResponse().integrations,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$": APIGatewayResponse().integration_responses,
"{0}/apikeys$": APIGatewayResponse().apikeys,
"{0}/apikeys/(?P<apikey>[^/]+)": APIGatewayResponse().apikey_individual,
"{0}/usageplans$": APIGatewayResponse().usage_plans,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$": APIGatewayResponse().usage_plan_individual,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$": APIGatewayResponse().usage_plan_keys,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual,
}
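A minimal sketch (not part of the diff) of the new apiKeySource validation path, assuming moto and boto3 are installed; the API name is illustrative:
import boto3
from botocore.exceptions import ClientError
from moto import mock_apigateway

@mock_apigateway
def create_api_with_bad_key_source():
    client = boto3.client("apigateway", region_name="us-east-1")
    try:
        # botocore does not validate enums client-side, so the bad value
        # reaches the mock and trips the ValidationException branch above
        client.create_rest_api(name="my_api", apiKeySource="INVALID")
    except ClientError as err:
        assert err.response["Error"]["Code"] == "ValidationException"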

View File

@ -7,4 +7,4 @@ import string
def create_id():
size = 10
chars = list(range(10)) + list(string.ascii_lowercase)
return "".join(six.text_type(random.choice(chars)) for x in range(size))

7
moto/athena/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from __future__ import unicode_literals
from .models import athena_backends
from ..core.models import base_decorator, deprecated_base_decorator
athena_backend = athena_backends["us-east-1"]
mock_athena = base_decorator(athena_backends)
mock_athena_deprecated = deprecated_base_decorator(athena_backends)

19
moto/athena/exceptions.py Normal file
View File

@ -0,0 +1,19 @@
from __future__ import unicode_literals
import json
from werkzeug.exceptions import BadRequest
class AthenaClientError(BadRequest):
def __init__(self, code, message):
super(AthenaClientError, self).__init__()
self.description = json.dumps(
{
"Error": {
"Code": code,
"Message": message,
"Type": "InvalidRequestException",
},
"RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1",
}
)
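A small sketch of what a caller sees when this error is raised (the code and message values are made up):
err = AthenaClientError("InvalidRequestException", "WorkGroup is disabled")
err.code         # 400, inherited from werkzeug's BadRequest
err.description  # the JSON payload assembled above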

86
moto/athena/models.py Normal file
View File

@ -0,0 +1,86 @@
from __future__ import unicode_literals
import time
from boto3 import Session
from moto.core import BaseBackend, BaseModel
from moto.core import ACCOUNT_ID
class TaggableResourceMixin(object):
# This mixin was copied from Redshift when initially implementing
# Athena. TBD if it's worth the overhead.
def __init__(self, region_name, resource_name, tags):
self.region = region_name
self.resource_name = resource_name
self.tags = tags or []
@property
def arn(self):
return "arn:aws:athena:{region}:{account_id}:{resource_name}".format(
region=self.region, account_id=ACCOUNT_ID, resource_name=self.resource_name
)
def create_tags(self, tags):
new_keys = [tag_set["Key"] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
self.tags.extend(tags)
return self.tags
def delete_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
return self.tags
class WorkGroup(TaggableResourceMixin, BaseModel):
resource_type = "workgroup"
state = "ENABLED"
def __init__(self, athena_backend, name, configuration, description, tags):
self.region_name = athena_backend.region_name
super(WorkGroup, self).__init__(
self.region_name, "workgroup/{}".format(name), tags
)
self.athena_backend = athena_backend
self.name = name
self.description = description
self.configuration = configuration
class AthenaBackend(BaseBackend):
region_name = None
def __init__(self, region_name=None):
if region_name is not None:
self.region_name = region_name
self.work_groups = {}
def create_work_group(self, name, configuration, description, tags):
if name in self.work_groups:
return None
work_group = WorkGroup(self, name, configuration, description, tags)
self.work_groups[name] = work_group
return work_group
def list_work_groups(self):
return [
{
"Name": wg.name,
"State": wg.state,
"Description": wg.description,
"CreationTime": time.time(),
}
for wg in self.work_groups.values()
]
athena_backends = {}
for region in Session().get_available_regions("athena"):
athena_backends[region] = AthenaBackend(region)
for region in Session().get_available_regions("athena", partition_name="aws-us-gov"):
athena_backends[region] = AthenaBackend(region)
for region in Session().get_available_regions("athena", partition_name="aws-cn"):
athena_backends[region] = AthenaBackend(region)
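A minimal sketch (not part of the diff) of driving this backend through boto3; the bucket name is illustrative:
import boto3
from moto import mock_athena

@mock_athena
def create_and_list_work_groups():
    client = boto3.client("athena", region_name="us-east-1")
    client.create_work_group(
        Name="reports",
        Description="nightly reporting queries",
        Configuration={
            "ResultConfiguration": {"OutputLocation": "s3://my-bucket/results/"}
        },
    )
    groups = client.list_work_groups()["WorkGroups"]
    assert groups[0]["Name"] == "reports"
    assert groups[0]["State"] == "ENABLED"
Calling create_work_group twice with the same Name exercises the "WorkGroup already exists" branch in responses.py below.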

41
moto/athena/responses.py Normal file
View File

@ -0,0 +1,41 @@
import json
from moto.core.responses import BaseResponse
from .models import athena_backends
class AthenaResponse(BaseResponse):
@property
def athena_backend(self):
return athena_backends[self.region]
def create_work_group(self):
name = self._get_param("Name")
description = self._get_param("Description")
configuration = self._get_param("Configuration")
tags = self._get_param("Tags")
work_group = self.athena_backend.create_work_group(
name, configuration, description, tags
)
if not work_group:
return (
json.dumps(
{
"__type": "InvalidRequestException",
"Message": "WorkGroup already exists",
}
),
dict(status=400),
)
return json.dumps(
{
"CreateWorkGroupResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a"
}
}
}
)
def list_work_groups(self):
return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()})

6
moto/athena/urls.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .responses import AthenaResponse
url_bases = ["https?://athena.(.+).amazonaws.com"]
url_paths = {"{0}/$": AthenaResponse.dispatch}

1
moto/athena/utils.py Normal file
View File

@ -0,0 +1 @@
from __future__ import unicode_literals

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import autoscaling_backends
from ..core.models import base_decorator, deprecated_base_decorator
autoscaling_backend = autoscaling_backends["us-east-1"]
mock_autoscaling = base_decorator(autoscaling_backends)
mock_autoscaling_deprecated = deprecated_base_decorator(autoscaling_backends)

View File

@ -12,13 +12,12 @@ class ResourceContentionError(RESTError):
def __init__(self):
super(ResourceContentionError, self).__init__(
"ResourceContentionError",
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).")
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).",
)
class InvalidInstanceError(AutoscalingClientError):
def __init__(self, instance_id):
super(InvalidInstanceError, self).__init__(
"ValidationError",
"Instance [{0}] is invalid."
.format(instance_id))
"ValidationError", "Instance [{0}] is invalid.".format(instance_id)
)
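A sketch of how InvalidInstanceError reaches a client, assuming the responses layer maps it to a standard error response; the instance id is made up:
import boto3
from botocore.exceptions import ClientError
from moto import mock_autoscaling, mock_ec2

@mock_autoscaling
@mock_ec2
def create_asg_from_missing_instance():
    client = boto3.client("autoscaling", region_name="us-east-1")
    try:
        client.create_auto_scaling_group(
            AutoScalingGroupName="test-asg",
            InstanceId="i-0123456789abcdef0",  # does not exist in the mock
            MinSize=1,
            MaxSize=1,
        )
    except ClientError as err:
        assert "is invalid" in err.response["Error"]["Message"]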

View File

@ -12,7 +12,9 @@ from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
from moto.elb.exceptions import LoadBalancerNotFoundError
from .exceptions import (
AutoscalingClientError,
ResourceContentionError,
InvalidInstanceError,
)
# http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
@ -22,8 +24,13 @@ ASG_NAME_TAG = "aws:autoscaling:groupName"
class InstanceState(object):
def __init__(self, instance, lifecycle_state="InService",
health_status="Healthy", protected_from_scale_in=False):
def __init__(
self,
instance,
lifecycle_state="InService",
health_status="Healthy",
protected_from_scale_in=False,
):
self.instance = instance
self.lifecycle_state = lifecycle_state
self.health_status = health_status
@ -31,8 +38,16 @@ class InstanceState(object):
class FakeScalingPolicy(BaseModel):
def __init__(
self,
name,
policy_type,
adjustment_type,
as_name,
scaling_adjustment,
cooldown,
autoscaling_backend,
):
self.name = name
self.policy_type = policy_type
self.adjustment_type = adjustment_type
@ -45,21 +60,38 @@ class FakeScalingPolicy(BaseModel):
self.autoscaling_backend = autoscaling_backend
def execute(self):
if self.adjustment_type == "ExactCapacity":
self.autoscaling_backend.set_desired_capacity(
self.as_name, self.scaling_adjustment
)
elif self.adjustment_type == "ChangeInCapacity":
self.autoscaling_backend.change_capacity(
self.as_name, self.scaling_adjustment
)
elif self.adjustment_type == "PercentChangeInCapacity":
self.autoscaling_backend.change_capacity_percent(
self.as_name, self.scaling_adjustment
)
class FakeLaunchConfiguration(BaseModel):
def __init__(
self,
name,
image_id,
key_name,
ramdisk_id,
kernel_id,
security_groups,
user_data,
instance_type,
instance_monitoring,
instance_profile_name,
spot_price,
ebs_optimized,
associate_public_ip_address,
block_device_mapping_dict,
):
self.name = name
self.image_id = image_id
self.key_name = key_name
@ -80,8 +112,8 @@ class FakeLaunchConfiguration(BaseModel):
config = backend.create_launch_configuration(
name=name,
image_id=instance.image_id,
kernel_id="",
ramdisk_id="",
key_name=instance.key_name,
security_groups=instance.security_groups,
user_data=instance.user_data,
@ -91,13 +123,15 @@ class FakeLaunchConfiguration(BaseModel):
spot_price=None,
ebs_optimized=instance.ebs_optimized,
associate_public_ip_address=instance.associate_public_ip,
block_device_mappings=instance.block_device_mapping,
)
return config
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
instance_profile_name = properties.get("IamInstanceProfile")
@ -115,20 +149,26 @@ class FakeLaunchConfiguration(BaseModel):
instance_profile_name=instance_profile_name,
spot_price=properties.get("SpotPrice"),
ebs_optimized=properties.get("EbsOptimized"),
associate_public_ip_address=properties.get("AssociatePublicIpAddress"),
block_device_mappings=properties.get("BlockDeviceMapping.member"),
)
return config
@classmethod
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name
)
return cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
backend = autoscaling_backends[region_name]
try:
backend.delete_launch_configuration(resource_name)
@ -153,34 +193,49 @@ class FakeLaunchConfiguration(BaseModel):
@property
def instance_monitoring_enabled(self):
if self.instance_monitoring:
return "true"
return "false"
def _parse_block_device_mappings(self):
block_device_map = BlockDeviceMapping()
for mapping in self.block_device_mapping_dict:
block_type = BlockDeviceType()
mount_point = mapping.get("device_name")
if "ephemeral" in mapping.get("virtual_name", ""):
block_type.ephemeral_name = mapping.get("virtual_name")
else:
block_type.volume_type = mapping.get("ebs._volume_type")
block_type.snapshot_id = mapping.get("ebs._snapshot_id")
block_type.delete_on_termination = mapping.get(
"ebs._delete_on_termination"
)
block_type.size = mapping.get("ebs._volume_size")
block_type.iops = mapping.get("ebs._iops")
block_device_map[mount_point] = block_type
return block_device_map
class FakeAutoScalingGroup(BaseModel):
def __init__(
self,
name,
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
load_balancers,
target_group_arns,
placement_group,
termination_policies,
autoscaling_backend,
tags,
new_instances_protected_from_scale_in=False,
):
self.autoscaling_backend = autoscaling_backend
self.name = name
@ -190,17 +245,22 @@ class FakeAutoScalingGroup(BaseModel):
self.min_size = min_size
self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name
]
self.launch_config_name = launch_config_name
self.default_cooldown = (
default_cooldown if default_cooldown else DEFAULT_COOLDOWN
)
self.health_check_period = health_check_period
self.health_check_type = health_check_type if health_check_type else "EC2"
self.load_balancers = load_balancers
self.target_group_arns = target_group_arns
self.placement_group = placement_group
self.termination_policies = termination_policies
self.new_instances_protected_from_scale_in = (
new_instances_protected_from_scale_in
)
self.suspended_processes = []
self.instance_states = []
@ -215,8 +275,10 @@ class FakeAutoScalingGroup(BaseModel):
if vpc_zone_identifier:
# extract azs for vpcs
subnet_ids = vpc_zone_identifier.split(",")
subnets = self.autoscaling_backend.ec2_backend.get_all_subnets(
subnet_ids=subnet_ids
)
vpc_zones = [subnet.availability_zone for subnet in subnets]
if availability_zones and set(availability_zones) != set(vpc_zones):
@ -229,7 +291,7 @@ class FakeAutoScalingGroup(BaseModel):
if not update:
raise AutoscalingClientError(
"ValidationError",
"At least one Availability Zone or VPC Subnet is required."
"At least one Availability Zone or VPC Subnet is required.",
)
return
@ -237,8 +299,10 @@ class FakeAutoScalingGroup(BaseModel):
self.vpc_zone_identifier = vpc_zone_identifier
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
launch_config_name = properties.get("LaunchConfigurationName")
load_balancer_names = properties.get("LoadBalancerNames", [])
@ -253,7 +317,8 @@ class FakeAutoScalingGroup(BaseModel):
min_size=properties.get("MinSize"),
launch_config_name=launch_config_name,
vpc_zone_identifier=(
",".join(properties.get("VPCZoneIdentifier", [])) or None
),
default_cooldown=properties.get("Cooldown"),
health_check_period=properties.get("HealthCheckGracePeriod"),
health_check_type=properties.get("HealthCheckType"),
@ -263,18 +328,26 @@ class FakeAutoScalingGroup(BaseModel):
termination_policies=properties.get("TerminationPolicies", []),
tags=properties.get("Tags", []),
new_instances_protected_from_scale_in=properties.get(
"NewInstancesProtectedFromScaleIn", False)
"NewInstancesProtectedFromScaleIn", False
),
)
return group
@classmethod
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name
)
return cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
backend = autoscaling_backends[region_name]
try:
backend.delete_auto_scaling_group(resource_name)
@ -289,11 +362,21 @@ class FakeAutoScalingGroup(BaseModel):
def physical_resource_id(self):
return self.name
def update(
self,
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=None,
):
self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier, update=True)
if max_size is not None:
@ -309,14 +392,17 @@ class FakeAutoScalingGroup(BaseModel):
if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name
]
self.launch_config_name = launch_config_name
if health_check_period is not None:
self.health_check_period = health_check_period
if health_check_type is not None:
self.health_check_type = health_check_type
if new_instances_protected_from_scale_in is not None:
self.new_instances_protected_from_scale_in = (
new_instances_protected_from_scale_in
)
if desired_capacity is not None:
self.set_desired_capacity(desired_capacity)
@ -342,25 +428,30 @@ class FakeAutoScalingGroup(BaseModel):
# Need to remove some instances
count_to_remove = curr_instance_count - self.desired_capacity
instances_to_remove = [ # only remove unprotected
state
for state in self.instance_states
if not state.protected_from_scale_in
][:count_to_remove]
if instances_to_remove:  # just in case there are no instances to remove
instance_ids_to_remove = [
instance.instance.id for instance in instances_to_remove
]
self.autoscaling_backend.ec2_backend.terminate_instances(
instance_ids_to_remove
)
self.instance_states = list(
set(self.instance_states) - set(instances_to_remove)
)
def get_propagated_tags(self):
propagated_tags = {}
for tag in self.tags:
# boto uses 'propagate_at_launch'
# boto3 and cloudformation use PropagateAtLaunch
if "propagate_at_launch" in tag and tag["propagate_at_launch"] == "true":
propagated_tags[tag["key"]] = tag["value"]
if "PropagateAtLaunch" in tag and tag["PropagateAtLaunch"]:
propagated_tags[tag["Key"]] = tag["Value"]
return propagated_tags
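# Illustrative example: given tags like
#   {"key": "env", "value": "prod", "propagate_at_launch": "true"}   (boto style)
#   {"Key": "team", "Value": "data", "PropagateAtLaunch": True}      (boto3/CFN style)
# this returns {"env": "prod", "team": "data"}.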
def replace_autoscaling_group_instances(self, count_needed, propagated_tags):
@ -371,15 +462,17 @@ class FakeAutoScalingGroup(BaseModel):
self.launch_config.user_data,
self.launch_config.security_groups,
instance_type=self.launch_config.instance_type,
tags={"instance": propagated_tags},
placement=random.choice(self.availability_zones),
)
for instance in reservation.instances:
instance.autoscaling_group = self
self.instance_states.append(
InstanceState(
instance,
protected_from_scale_in=self.new_instances_protected_from_scale_in,
)
)
def append_target_groups(self, target_group_arns):
append = [x for x in target_group_arns if x not in self.target_group_arns]
@ -402,10 +495,23 @@ class AutoScalingBackend(BaseBackend):
self.__dict__ = {}
self.__init__(ec2_backend, elb_backend, elbv2_backend)
def create_launch_configuration(
self,
name,
image_id,
key_name,
kernel_id,
ramdisk_id,
security_groups,
user_data,
instance_type,
instance_monitoring,
instance_profile_name,
spot_price,
ebs_optimized,
associate_public_ip_address,
block_device_mappings,
):
launch_configuration = FakeLaunchConfiguration(
name=name,
image_id=image_id,
@ -428,23 +534,37 @@ class AutoScalingBackend(BaseBackend):
def describe_launch_configurations(self, names):
configurations = self.launch_configurations.values()
if names:
return [
configuration
for configuration in configurations
if configuration.name in names
]
else:
return list(configurations)
def delete_launch_configuration(self, launch_configuration_name):
self.launch_configurations.pop(launch_configuration_name, None)
def create_auto_scaling_group(
self,
name,
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
load_balancers,
target_group_arns,
placement_group,
termination_policies,
tags,
new_instances_protected_from_scale_in=False,
instance_id=None,
):
def make_int(value):
return int(value) if value is not None else value
@ -460,7 +580,9 @@ class AutoScalingBackend(BaseBackend):
try:
instance = self.ec2_backend.get_instance(instance_id)
launch_config_name = name
FakeLaunchConfiguration.create_from_instance(
launch_config_name, instance, self
)
except InvalidInstanceIdError:
raise InvalidInstanceError(instance_id)
@ -489,19 +611,37 @@ class AutoScalingBackend(BaseBackend):
self.update_attached_target_groups(group.name)
return group
def update_auto_scaling_group(
self,
name,
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=None,
):
group = self.autoscaling_groups[name]
group.update(
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in,
)
return group
def describe_auto_scaling_groups(self, names):
@ -537,32 +677,48 @@ class AutoScalingBackend(BaseBackend):
for x in instance_ids
]
for instance in new_instances:
self.ec2_backend.create_tags(
[instance.instance.id], {ASG_NAME_TAG: group.name}
)
group.instance_states.extend(new_instances)
self.update_attached_elbs(group.name)
def set_instance_health(
self, instance_id, health_status, should_respect_grace_period
):
instance = self.ec2_backend.get_instance(instance_id)
instance_state = next(
instance_state
for group in self.autoscaling_groups.values()
for instance_state in group.instance_states
if instance_state.instance.id == instance.id
)
instance_state.health_status = health_status
def detach_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name]
original_size = len(group.instance_states)
detached_instances = [
x for x in group.instance_states if x.instance.id in instance_ids
]
for instance in detached_instances:
self.ec2_backend.delete_tags(
[instance.instance.id], {ASG_NAME_TAG: group.name}
)
new_instance_state = [
x for x in group.instance_states if x.instance.id not in instance_ids
]
group.instance_states = new_instance_state
if should_decrement:
group.desired_capacity = original_size - len(instance_ids)
else:
count_needed = len(instance_ids)
group.replace_autoscaling_group_instances(
count_needed, group.get_propagated_tags()
)
self.update_attached_elbs(group_name)
return detached_instances
@ -593,19 +749,32 @@ class AutoScalingBackend(BaseBackend):
desired_capacity = int(desired_capacity)
self.set_desired_capacity(group_name, desired_capacity)
def create_autoscaling_policy(
self, name, policy_type, adjustment_type, as_name, scaling_adjustment, cooldown
):
policy = FakeScalingPolicy(
name,
policy_type,
adjustment_type,
as_name,
scaling_adjustment,
cooldown,
self,
)
self.policies[name] = policy
return policy
def describe_policies(
self, autoscaling_group_name=None, policy_names=None, policy_types=None
):
return [
policy
for policy in self.policies.values()
if (not autoscaling_group_name or policy.as_name == autoscaling_group_name)
and (not policy_names or policy.name in policy_names)
and (not policy_types or policy.policy_type in policy_types)
]
def delete_policy(self, group_name):
self.policies.pop(group_name, None)
@ -616,16 +785,14 @@ class AutoScalingBackend(BaseBackend):
def update_attached_elbs(self, group_name):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(state.instance.id for state in group.instance_states)
# skip this if group.load_balancers is empty
# otherwise elb_backend.describe_load_balancers returns all available load balancers
if not group.load_balancers:
return
try:
elbs = self.elb_backend.describe_load_balancers(
names=group.load_balancers)
elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
except LoadBalancerNotFoundError:
# ELBs can be deleted before their autoscaling group
return
@ -633,14 +800,15 @@ class AutoScalingBackend(BaseBackend):
for elb in elbs:
elb_instance_ids = set(elb.instance_ids)
self.elb_backend.register_instances(
elb.name, group_instance_ids - elb_instance_ids)
elb.name, group_instance_ids - elb_instance_ids
)
self.elb_backend.deregister_instances(
elb.name, elb_instance_ids - group_instance_ids)
elb.name, elb_instance_ids - group_instance_ids
)
def update_attached_target_groups(self, group_name):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(
state.instance.id for state in group.instance_states)
group_instance_ids = set(state.instance.id for state in group.instance_states)
# no action necessary if target_group_arns is empty
if not group.target_group_arns:
@ -649,10 +817,13 @@ class AutoScalingBackend(BaseBackend):
target_groups = self.elbv2_backend.describe_target_groups(
target_group_arns=group.target_group_arns,
load_balancer_arn=None,
names=None)
names=None,
)
for target_group in target_groups:
asg_targets = [{'id': x, 'port': target_group.port} for x in group_instance_ids]
asg_targets = [
{"id": x, "port": target_group.port} for x in group_instance_ids
]
self.elbv2_backend.register_targets(target_group.arn, (asg_targets))
def create_or_update_tags(self, tags):
@ -670,7 +841,7 @@ class AutoScalingBackend(BaseBackend):
new_tags.append(old_tag)
# if the key was never present in the old tags, add it (create tag)
if not any(new_tag['key'] == tag['key'] for new_tag in new_tags):
if not any(new_tag["key"] == tag["key"] for new_tag in new_tags):
new_tags.append(tag)
group.tags = new_tags
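The merge above keeps at most one tag per key. A standalone restatement of that rule, illustrative only:

def merge_tags(existing, updates):
    # An incoming tag replaces any existing tag with the same key; existing
    # tags with other keys are kept; previously unseen keys are appended.
    updates_by_key = {tag["key"]: tag for tag in updates}
    merged = [updates_by_key.pop(tag["key"], tag) for tag in existing]
    merged.extend(updates_by_key.values())
    return merged

assert merge_tags(
    [{"key": "env", "value": "dev"}, {"key": "team", "value": "core"}],
    [{"key": "env", "value": "prod"}, {"key": "owner", "value": "ops"}],
) == [
    {"key": "env", "value": "prod"},
    {"key": "team", "value": "core"},
    {"key": "owner", "value": "ops"},
]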
@ -678,7 +849,8 @@ class AutoScalingBackend(BaseBackend):
def attach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name]
group.load_balancers.extend(
[x for x in load_balancer_names if x not in group.load_balancers])
[x for x in load_balancer_names if x not in group.load_balancers]
)
self.update_attached_elbs(group_name)
def describe_load_balancers(self, group_name):
@ -686,13 +858,13 @@ class AutoScalingBackend(BaseBackend):
def detach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(
state.instance.id for state in group.instance_states)
group_instance_ids = set(state.instance.id for state in group.instance_states)
elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
for elb in elbs:
self.elb_backend.deregister_instances(
elb.name, group_instance_ids)
group.load_balancers = [x for x in group.load_balancers if x not in load_balancer_names]
self.elb_backend.deregister_instances(elb.name, group_instance_ids)
group.load_balancers = [
x for x in group.load_balancers if x not in load_balancer_names
]
def attach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name]
@ -704,36 +876,51 @@ class AutoScalingBackend(BaseBackend):
def detach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name]
group.target_group_arns = [x for x in group.target_group_arns if x not in target_group_arns]
group.target_group_arns = [
x for x in group.target_group_arns if x not in target_group_arns
]
for target_group in target_group_arns:
asg_targets = [{'id': x.instance.id} for x in group.instance_states]
asg_targets = [{"id": x.instance.id} for x in group.instance_states]
self.elbv2_backend.deregister_targets(target_group, (asg_targets))
def suspend_processes(self, group_name, scaling_processes):
group = self.autoscaling_groups[group_name]
group.suspended_processes = scaling_processes or []
def set_instance_protection(self, group_name, instance_ids, protected_from_scale_in):
def set_instance_protection(
self, group_name, instance_ids, protected_from_scale_in
):
group = self.autoscaling_groups[group_name]
protected_instances = [
x for x in group.instance_states if x.instance.id in instance_ids]
x for x in group.instance_states if x.instance.id in instance_ids
]
for instance in protected_instances:
instance.protected_from_scale_in = protected_from_scale_in
def notify_terminate_instances(self, instance_ids):
for autoscaling_group_name, autoscaling_group in self.autoscaling_groups.items():
for (
autoscaling_group_name,
autoscaling_group,
) in self.autoscaling_groups.items():
original_instance_count = len(autoscaling_group.instance_states)
autoscaling_group.instance_states = list(filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states = list(
filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states,
)
)
difference = original_instance_count - len(
autoscaling_group.instance_states
))
difference = original_instance_count - len(autoscaling_group.instance_states)
)
if difference > 0:
autoscaling_group.replace_autoscaling_group_instances(difference, autoscaling_group.get_propagated_tags())
autoscaling_group.replace_autoscaling_group_instances(
difference, autoscaling_group.get_propagated_tags()
)
self.update_attached_elbs(autoscaling_group_name)
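notify_terminate_instances is the hook the EC2 backend calls on termination, after which the group backfills to its desired capacity. A hedged end-to-end sketch with invented names:

import boto3
from moto import mock_autoscaling, mock_ec2

@mock_ec2
@mock_autoscaling
def test_terminated_instance_is_replaced():
    asg = boto3.client("autoscaling", region_name="us-east-1")
    ec2 = boto3.client("ec2", region_name="us-east-1")
    asg.create_launch_configuration(
        LaunchConfigurationName="lc", ImageId="ami-12345678", InstanceType="t2.micro"
    )
    asg.create_auto_scaling_group(
        AutoScalingGroupName="asg",
        LaunchConfigurationName="lc",
        MinSize=1,
        MaxSize=2,
        DesiredCapacity=1,
        AvailabilityZones=["us-east-1a"],
    )
    old_id = asg.describe_auto_scaling_groups(AutoScalingGroupNames=["asg"])[
        "AutoScalingGroups"
    ][0]["Instances"][0]["InstanceId"]
    ec2.terminate_instances(InstanceIds=[old_id])
    instances = asg.describe_auto_scaling_groups(AutoScalingGroupNames=["asg"])[
        "AutoScalingGroups"
    ][0]["Instances"]
    # difference > 0 above, so one replacement instance was launched
    assert len(instances) == 1 and instances[0]["InstanceId"] != old_id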
autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items():
autoscaling_backends[region] = AutoScalingBackend(
ec2_backend, elb_backends[region], elbv2_backends[region])
ec2_backend, elb_backends[region], elbv2_backends[region]
)

moto/autoscaling/responses.py

@ -6,88 +6,88 @@ from .models import autoscaling_backends
class AutoScalingResponse(BaseResponse):
@property
def autoscaling_backend(self):
return autoscaling_backends[self.region]
def create_launch_configuration(self):
instance_monitoring_string = self._get_param(
'InstanceMonitoring.Enabled')
if instance_monitoring_string == 'true':
instance_monitoring_string = self._get_param("InstanceMonitoring.Enabled")
if instance_monitoring_string == "true":
instance_monitoring = True
else:
instance_monitoring = False
self.autoscaling_backend.create_launch_configuration(
name=self._get_param('LaunchConfigurationName'),
image_id=self._get_param('ImageId'),
key_name=self._get_param('KeyName'),
ramdisk_id=self._get_param('RamdiskId'),
kernel_id=self._get_param('KernelId'),
security_groups=self._get_multi_param('SecurityGroups.member'),
user_data=self._get_param('UserData'),
instance_type=self._get_param('InstanceType'),
name=self._get_param("LaunchConfigurationName"),
image_id=self._get_param("ImageId"),
key_name=self._get_param("KeyName"),
ramdisk_id=self._get_param("RamdiskId"),
kernel_id=self._get_param("KernelId"),
security_groups=self._get_multi_param("SecurityGroups.member"),
user_data=self._get_param("UserData"),
instance_type=self._get_param("InstanceType"),
instance_monitoring=instance_monitoring,
instance_profile_name=self._get_param('IamInstanceProfile'),
spot_price=self._get_param('SpotPrice'),
ebs_optimized=self._get_param('EbsOptimized'),
associate_public_ip_address=self._get_param(
"AssociatePublicIpAddress"),
block_device_mappings=self._get_list_prefix(
'BlockDeviceMappings.member')
instance_profile_name=self._get_param("IamInstanceProfile"),
spot_price=self._get_param("SpotPrice"),
ebs_optimized=self._get_param("EbsOptimized"),
associate_public_ip_address=self._get_param("AssociatePublicIpAddress"),
block_device_mappings=self._get_list_prefix("BlockDeviceMappings.member"),
)
template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render()
def describe_launch_configurations(self):
names = self._get_multi_param('LaunchConfigurationNames.member')
all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(names)
marker = self._get_param('NextToken')
names = self._get_multi_param("LaunchConfigurationNames.member")
all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(
names
)
marker = self._get_param("NextToken")
all_names = [lc.name for lc in all_launch_configurations]
if marker:
start = all_names.index(marker) + 1
else:
start = 0
max_records = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
launch_configurations_resp = all_launch_configurations[start:start + max_records]
max_records = self._get_int_param(
"MaxRecords", 50
) # the default is 100, but using 50 to make testing easier
launch_configurations_resp = all_launch_configurations[
start : start + max_records
]
next_token = None
if len(all_launch_configurations) > start + max_records:
next_token = launch_configurations_resp[-1].name
template = self.response_template(
DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE)
return template.render(launch_configurations=launch_configurations_resp, next_token=next_token)
template = self.response_template(DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE)
return template.render(
launch_configurations=launch_configurations_resp, next_token=next_token
)
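The paging contract this implements: pages of 50 (deliberately below the real default of 100), with the last returned name doubling as the NextToken. A sketch:

import boto3
from moto import mock_autoscaling

@mock_autoscaling
def test_launch_configuration_paging():
    client = boto3.client("autoscaling", region_name="us-east-1")
    for i in range(60):
        client.create_launch_configuration(
            LaunchConfigurationName="lc-%02d" % i,
            ImageId="ami-12345678",
            InstanceType="t2.micro",
        )
    first = client.describe_launch_configurations()
    assert len(first["LaunchConfigurations"]) == 50
    second = client.describe_launch_configurations(NextToken=first["NextToken"])
    assert len(second["LaunchConfigurations"]) == 10
    assert "NextToken" not in second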
def delete_launch_configuration(self):
launch_configurations_name = self.querystring.get(
'LaunchConfigurationName')[0]
self.autoscaling_backend.delete_launch_configuration(
launch_configurations_name)
launch_configurations_name = self.querystring.get("LaunchConfigurationName")[0]
self.autoscaling_backend.delete_launch_configuration(launch_configurations_name)
template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render()
def create_auto_scaling_group(self):
self.autoscaling_backend.create_auto_scaling_group(
name=self._get_param('AutoScalingGroupName'),
availability_zones=self._get_multi_param(
'AvailabilityZones.member'),
desired_capacity=self._get_int_param('DesiredCapacity'),
max_size=self._get_int_param('MaxSize'),
min_size=self._get_int_param('MinSize'),
instance_id=self._get_param('InstanceId'),
launch_config_name=self._get_param('LaunchConfigurationName'),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'),
default_cooldown=self._get_int_param('DefaultCooldown'),
health_check_period=self._get_int_param('HealthCheckGracePeriod'),
health_check_type=self._get_param('HealthCheckType'),
load_balancers=self._get_multi_param('LoadBalancerNames.member'),
target_group_arns=self._get_multi_param('TargetGroupARNs.member'),
placement_group=self._get_param('PlacementGroup'),
termination_policies=self._get_multi_param(
'TerminationPolicies.member'),
tags=self._get_list_prefix('Tags.member'),
name=self._get_param("AutoScalingGroupName"),
availability_zones=self._get_multi_param("AvailabilityZones.member"),
desired_capacity=self._get_int_param("DesiredCapacity"),
max_size=self._get_int_param("MaxSize"),
min_size=self._get_int_param("MinSize"),
instance_id=self._get_param("InstanceId"),
launch_config_name=self._get_param("LaunchConfigurationName"),
vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
default_cooldown=self._get_int_param("DefaultCooldown"),
health_check_period=self._get_int_param("HealthCheckGracePeriod"),
health_check_type=self._get_param("HealthCheckType"),
load_balancers=self._get_multi_param("LoadBalancerNames.member"),
target_group_arns=self._get_multi_param("TargetGroupARNs.member"),
placement_group=self._get_param("PlacementGroup"),
termination_policies=self._get_multi_param("TerminationPolicies.member"),
tags=self._get_list_prefix("Tags.member"),
new_instances_protected_from_scale_in=self._get_bool_param(
'NewInstancesProtectedFromScaleIn', False)
"NewInstancesProtectedFromScaleIn", False
),
)
template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render()
@ -95,68 +95,73 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32
@amzn_request_id
def attach_instances(self):
group_name = self._get_param('AutoScalingGroupName')
instance_ids = self._get_multi_param('InstanceIds.member')
self.autoscaling_backend.attach_instances(
group_name, instance_ids)
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
self.autoscaling_backend.attach_instances(group_name, instance_ids)
template = self.response_template(ATTACH_INSTANCES_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def set_instance_health(self):
instance_id = self._get_param('InstanceId')
instance_id = self._get_param("InstanceId")
health_status = self._get_param("HealthStatus")
if health_status not in ['Healthy', 'Unhealthy']:
raise ValueError('Valid instance health states are: [Healthy, Unhealthy]')
if health_status not in ["Healthy", "Unhealthy"]:
raise ValueError("Valid instance health states are: [Healthy, Unhealthy]")
should_respect_grace_period = self._get_param("ShouldRespectGracePeriod")
self.autoscaling_backend.set_instance_health(instance_id, health_status, should_respect_grace_period)
self.autoscaling_backend.set_instance_health(
instance_id, health_status, should_respect_grace_period
)
template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def detach_instances(self):
group_name = self._get_param('AutoScalingGroupName')
instance_ids = self._get_multi_param('InstanceIds.member')
should_decrement_string = self._get_param('ShouldDecrementDesiredCapacity')
if should_decrement_string == 'true':
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
if should_decrement_string == "true":
should_decrement = True
else:
should_decrement = False
detached_instances = self.autoscaling_backend.detach_instances(
group_name, instance_ids, should_decrement)
group_name, instance_ids, should_decrement
)
template = self.response_template(DETACH_INSTANCES_TEMPLATE)
return template.render(detached_instances=detached_instances)
@amz_crc32
@amzn_request_id
def attach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
target_group_arns = self._get_multi_param('TargetGroupARNs.member')
group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self._get_multi_param("TargetGroupARNs.member")
self.autoscaling_backend.attach_load_balancer_target_groups(
group_name, target_group_arns)
group_name, target_group_arns
)
template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def describe_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups(
group_name)
group_name
)
template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS)
return template.render(target_group_arns=target_group_arns)
@amz_crc32
@amzn_request_id
def detach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
target_group_arns = self._get_multi_param('TargetGroupARNs.member')
group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self._get_multi_param("TargetGroupARNs.member")
self.autoscaling_backend.detach_load_balancer_target_groups(
group_name, target_group_arns)
group_name, target_group_arns
)
template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render()
@ -172,7 +177,7 @@ class AutoScalingResponse(BaseResponse):
max_records = self._get_int_param("MaxRecords", 50)
if max_records > 100:
raise ValueError
groups = all_groups[start:start + max_records]
groups = all_groups[start : start + max_records]
next_token = None
if max_records and len(all_groups) > start + max_records:
next_token = groups[-1].name
@ -181,42 +186,40 @@ class AutoScalingResponse(BaseResponse):
def update_auto_scaling_group(self):
self.autoscaling_backend.update_auto_scaling_group(
name=self._get_param('AutoScalingGroupName'),
availability_zones=self._get_multi_param(
'AvailabilityZones.member'),
desired_capacity=self._get_int_param('DesiredCapacity'),
max_size=self._get_int_param('MaxSize'),
min_size=self._get_int_param('MinSize'),
launch_config_name=self._get_param('LaunchConfigurationName'),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'),
default_cooldown=self._get_int_param('DefaultCooldown'),
health_check_period=self._get_int_param('HealthCheckGracePeriod'),
health_check_type=self._get_param('HealthCheckType'),
placement_group=self._get_param('PlacementGroup'),
termination_policies=self._get_multi_param(
'TerminationPolicies.member'),
name=self._get_param("AutoScalingGroupName"),
availability_zones=self._get_multi_param("AvailabilityZones.member"),
desired_capacity=self._get_int_param("DesiredCapacity"),
max_size=self._get_int_param("MaxSize"),
min_size=self._get_int_param("MinSize"),
launch_config_name=self._get_param("LaunchConfigurationName"),
vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
default_cooldown=self._get_int_param("DefaultCooldown"),
health_check_period=self._get_int_param("HealthCheckGracePeriod"),
health_check_type=self._get_param("HealthCheckType"),
placement_group=self._get_param("PlacementGroup"),
termination_policies=self._get_multi_param("TerminationPolicies.member"),
new_instances_protected_from_scale_in=self._get_bool_param(
'NewInstancesProtectedFromScaleIn', None)
"NewInstancesProtectedFromScaleIn", None
),
)
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render()
def delete_auto_scaling_group(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
self.autoscaling_backend.delete_auto_scaling_group(group_name)
template = self.response_template(DELETE_AUTOSCALING_GROUP_TEMPLATE)
return template.render()
def set_desired_capacity(self):
group_name = self._get_param('AutoScalingGroupName')
desired_capacity = self._get_int_param('DesiredCapacity')
self.autoscaling_backend.set_desired_capacity(
group_name, desired_capacity)
group_name = self._get_param("AutoScalingGroupName")
desired_capacity = self._get_int_param("DesiredCapacity")
self.autoscaling_backend.set_desired_capacity(group_name, desired_capacity)
template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE)
return template.render()
def create_or_update_tags(self):
tags = self._get_list_prefix('Tags.member')
tags = self._get_list_prefix("Tags.member")
self.autoscaling_backend.create_or_update_tags(tags)
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
@ -224,38 +227,38 @@ class AutoScalingResponse(BaseResponse):
def describe_auto_scaling_instances(self):
instance_states = self.autoscaling_backend.describe_auto_scaling_instances()
template = self.response_template(
DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
return template.render(instance_states=instance_states)
def put_scaling_policy(self):
policy = self.autoscaling_backend.create_autoscaling_policy(
name=self._get_param('PolicyName'),
policy_type=self._get_param('PolicyType'),
adjustment_type=self._get_param('AdjustmentType'),
as_name=self._get_param('AutoScalingGroupName'),
scaling_adjustment=self._get_int_param('ScalingAdjustment'),
cooldown=self._get_int_param('Cooldown'),
name=self._get_param("PolicyName"),
policy_type=self._get_param("PolicyType"),
adjustment_type=self._get_param("AdjustmentType"),
as_name=self._get_param("AutoScalingGroupName"),
scaling_adjustment=self._get_int_param("ScalingAdjustment"),
cooldown=self._get_int_param("Cooldown"),
)
template = self.response_template(CREATE_SCALING_POLICY_TEMPLATE)
return template.render(policy=policy)
def describe_policies(self):
policies = self.autoscaling_backend.describe_policies(
autoscaling_group_name=self._get_param('AutoScalingGroupName'),
policy_names=self._get_multi_param('PolicyNames.member'),
policy_types=self._get_multi_param('PolicyTypes.member'))
autoscaling_group_name=self._get_param("AutoScalingGroupName"),
policy_names=self._get_multi_param("PolicyNames.member"),
policy_types=self._get_multi_param("PolicyTypes.member"),
)
template = self.response_template(DESCRIBE_SCALING_POLICIES_TEMPLATE)
return template.render(policies=policies)
def delete_policy(self):
group_name = self._get_param('PolicyName')
group_name = self._get_param("PolicyName")
self.autoscaling_backend.delete_policy(group_name)
template = self.response_template(DELETE_POLICY_TEMPLATE)
return template.render()
def execute_policy(self):
group_name = self._get_param('PolicyName')
group_name = self._get_param("PolicyName")
self.autoscaling_backend.execute_policy(group_name)
template = self.response_template(EXECUTE_POLICY_TEMPLATE)
return template.render()
@ -263,17 +266,16 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32
@amzn_request_id
def attach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.attach_load_balancers(
group_name, load_balancer_names)
self.autoscaling_backend.attach_load_balancers(group_name, load_balancer_names)
template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def describe_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
load_balancers = self.autoscaling_backend.describe_load_balancers(group_name)
template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE)
return template.render(load_balancers=load_balancers)
@ -281,26 +283,28 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32
@amzn_request_id
def detach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.detach_load_balancers(
group_name, load_balancer_names)
self.autoscaling_backend.detach_load_balancers(group_name, load_balancer_names)
template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE)
return template.render()
def suspend_processes(self):
autoscaling_group_name = self._get_param('AutoScalingGroupName')
scaling_processes = self._get_multi_param('ScalingProcesses.member')
self.autoscaling_backend.suspend_processes(autoscaling_group_name, scaling_processes)
autoscaling_group_name = self._get_param("AutoScalingGroupName")
scaling_processes = self._get_multi_param("ScalingProcesses.member")
self.autoscaling_backend.suspend_processes(
autoscaling_group_name, scaling_processes
)
template = self.response_template(SUSPEND_PROCESSES_TEMPLATE)
return template.render()
def set_instance_protection(self):
group_name = self._get_param('AutoScalingGroupName')
instance_ids = self._get_multi_param('InstanceIds.member')
protected_from_scale_in = self._get_bool_param('ProtectedFromScaleIn')
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
protected_from_scale_in = self._get_bool_param("ProtectedFromScaleIn")
self.autoscaling_backend.set_instance_protection(
group_name, instance_ids, protected_from_scale_in)
group_name, instance_ids, protected_from_scale_in
)
template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE)
return template.render()

moto/autoscaling/urls.py

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import AutoScalingResponse
url_bases = [
"https?://autoscaling.(.+).amazonaws.com",
]
url_bases = ["https?://autoscaling.(.+).amazonaws.com"]
url_paths = {
'{0}/$': AutoScalingResponse.dispatch,
}
url_paths = {"{0}/$": AutoScalingResponse.dispatch}

moto/awslambda/__init__.py

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import lambda_backends
from ..core.models import base_decorator, deprecated_base_decorator
lambda_backend = lambda_backends['us-east-1']
lambda_backend = lambda_backends["us-east-1"]
mock_lambda = base_decorator(lambda_backends)
mock_lambda_deprecated = deprecated_base_decorator(lambda_backends)
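base_decorator turns the per-region backend map into the usual moto decorator. A minimal usage sketch:

import boto3
from moto import mock_lambda

@mock_lambda
def test_backend_starts_empty():
    # Each region resolves to its own isolated in-memory backend.
    client = boto3.client("lambda", region_name="eu-west-1")
    assert client.list_functions()["Functions"] == []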

moto/awslambda/exceptions.py (new file)

@ -0,0 +1,41 @@
from botocore.client import ClientError
from moto.core.exceptions import JsonRESTError
class LambdaClientError(ClientError):
def __init__(self, error, message):
error_response = {"Error": {"Code": error, "Message": message}}
super(LambdaClientError, self).__init__(error_response, None)
class CrossAccountNotAllowed(LambdaClientError):
def __init__(self):
super(CrossAccountNotAllowed, self).__init__(
"AccessDeniedException", "Cross-account pass role is not allowed."
)
class InvalidParameterValueException(LambdaClientError):
def __init__(self, message):
super(InvalidParameterValueException, self).__init__(
"InvalidParameterValueException", message
)
class InvalidRoleFormat(LambdaClientError):
pattern = r"arn:(aws[a-zA-Z-]*)?:iam::(\d{12}):role/?[a-zA-Z_0-9+=,.@\-_/]+"
def __init__(self, role):
message = "1 validation error detected: Value '{0}' at 'role' failed to satisfy constraint: Member must satisfy regular expression pattern: {1}".format(
role, InvalidRoleFormat.pattern
)
super(InvalidRoleFormat, self).__init__("ValidationException", message)
class PreconditionFailedException(JsonRESTError):
code = 412
def __init__(self, message):
super(PreconditionFailedException, self).__init__(
"PreconditionFailedException", message
)

moto/awslambda/models.py: file diff suppressed because it is too large.

moto/awslambda/policy.py (new file, 134 lines)

@ -0,0 +1,134 @@
from __future__ import unicode_literals
import json
import uuid
from six import string_types
from moto.awslambda.exceptions import PreconditionFailedException
class Policy:
def __init__(self, parent):
self.revision = str(uuid.uuid4())
self.statements = []
self.parent = parent
def wire_format(self):
p = self.get_policy()
p["Policy"] = json.dumps(p["Policy"])
return json.dumps(p)
def get_policy(self):
return {
"Policy": {
"Version": "2012-10-17",
"Id": "default",
"Statement": self.statements,
},
"RevisionId": self.revision,
}
# adds the raw JSON statement to the policy
def add_statement(self, raw):
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
self.statements.append(policy.statements[0])
self.revision = str(uuid.uuid4())
# removes the statement that matches 'sid' from the policy
def del_statement(self, sid, revision=""):
if len(revision) > 0 and self.revision != revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
for statement in self.statements:
if "Sid" in statement and statement["Sid"] == sid:
self.statements.remove(statement)
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj):
# import pydevd
# pydevd.settrace("localhost", port=5678)
policy = Policy(self.parent)
policy.revision = obj.get("RevisionId", "")
# set some default values if these keys are not set
self.ensure_set(obj, "Effect", "Allow")
self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST")
self.ensure_set(obj, "StatementId", str(uuid.uuid4()))
# transform field names and values
self.transform_property(obj, "StatementId", "Sid", self.nop_formatter)
self.transform_property(obj, "Principal", "Principal", self.principal_formatter)
self.transform_property(
obj, "SourceArn", "SourceArn", self.source_arn_formatter
)
self.transform_property(
obj, "SourceAccount", "SourceAccount", self.source_account_formatter
)
# remove RevisionId and EventSourceToken if they are set
self.remove_if_set(obj, ["RevisionId", "EventSourceToken"])
# merge conditional statements into a single map under the Condition key
self.condition_merge(obj)
# append resulting statement to policy.statements
policy.statements.append(obj)
return policy
def nop_formatter(self, obj):
return obj
def ensure_set(self, obj, key, value):
if key not in obj:
obj[key] = value
def principal_formatter(self, obj):
if isinstance(obj, string_types):
if obj.endswith(".amazonaws.com"):
return {"Service": obj}
if obj.endswith(":root"):
return {"AWS": obj}
return obj
def source_account_formatter(self, obj):
return {"StringEquals": {"AWS:SourceAccount": obj}}
def source_arn_formatter(self, obj):
return {"ArnLike": {"AWS:SourceArn": obj}}
def transform_property(self, obj, old_name, new_name, formatter):
if old_name in obj:
obj[new_name] = formatter(obj[old_name])
if new_name != old_name:
del obj[old_name]
def remove_if_set(self, obj, keys):
for key in keys:
if key in obj:
del obj[key]
def condition_merge(self, obj):
if "SourceArn" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceArn"])
del obj["SourceArn"]
if "SourceAccount" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceAccount"])
del obj["SourceAccount"]

moto/awslambda/responses.py

@ -32,32 +32,57 @@ class LambdaResponse(BaseResponse):
def root(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
return self._list_functions(request, full_url, headers)
elif request.method == 'POST':
elif request.method == "POST":
return self._create_function(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def event_source_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":
querystring = self.querystring
event_source_arn = querystring.get("EventSourceArn", [None])[0]
function_name = querystring.get("FunctionName", [None])[0]
return self._list_event_source_mappings(event_source_arn, function_name)
elif request.method == "POST":
return self._create_event_source_mapping(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def event_source_mapping(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
path = request.path if hasattr(request, "path") else path_url(request.url)
uuid = path.split("/")[-1]
if request.method == "GET":
return self._get_event_source_mapping(uuid)
elif request.method == "PUT":
return self._update_event_source_mapping(uuid)
elif request.method == "DELETE":
return self._delete_event_source_mapping(uuid)
else:
raise ValueError("Cannot handle request")
def function(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
return self._get_function(request, full_url, headers)
elif request.method == 'DELETE':
elif request.method == "DELETE":
return self._delete_function(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def versions(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
# This is ListVersionsByFunction
path = request.path if hasattr(request, 'path') else path_url(request.url)
function_name = path.split('/')[-2]
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
return self._list_versions_by_function(function_name)
elif request.method == 'POST':
elif request.method == "POST":
return self._publish_function(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
@ -66,7 +91,7 @@ class LambdaResponse(BaseResponse):
@amzn_request_id
def invoke(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'POST':
if request.method == "POST":
return self._invoke(request, full_url)
else:
raise ValueError("Cannot handle request")
@ -75,110 +100,176 @@ class LambdaResponse(BaseResponse):
@amzn_request_id
def invoke_async(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'POST':
if request.method == "POST":
return self._invoke_async(request, full_url)
else:
raise ValueError("Cannot handle request")
def tag(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
return self._list_tags(request, full_url)
elif request.method == 'POST':
elif request.method == "POST":
return self._tag_resource(request, full_url)
elif request.method == 'DELETE':
elif request.method == "DELETE":
return self._untag_resource(request, full_url)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def policy(self, request, full_url, headers):
if request.method == 'GET':
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._get_policy(request, full_url, headers)
if request.method == 'POST':
elif request.method == "POST":
return self._add_policy(request, full_url, headers)
elif request.method == "DELETE":
return self._del_policy(request, full_url, headers, self.querystring)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self._put_configuration(request)
else:
raise ValueError("Cannot handle request")
def code(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self._put_code()
else:
raise ValueError("Cannot handle request")
def _add_policy(self, request, full_url, headers):
path = request.path if hasattr(request, 'path') else path_url(request.url)
function_name = path.split('/')[-2]
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
policy = request.body.decode('utf8')
self.lambda_backend.add_policy(function_name, policy)
return 200, {}, json.dumps(dict(Statement=policy))
statement = self.body
self.lambda_backend.add_policy_statement(function_name, statement)
return 200, {}, json.dumps({"Statement": statement})
else:
return 404, {}, "{}"
def _get_policy(self, request, full_url, headers):
path = request.path if hasattr(request, 'path') else path_url(request.url)
function_name = path.split('/')[-2]
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
lambda_function = self.lambda_backend.get_function(function_name)
return 200, {}, json.dumps(dict(Policy="{\"Statement\":[" + lambda_function.policy + "]}"))
out = self.lambda_backend.get_policy_wire_format(function_name)
return 200, {}, out
else:
return 404, {}, "{}"
def _del_policy(self, request, full_url, headers, querystring):
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-3]
statement_id = path.split("/")[-1].split("?")[0]
revision = querystring.get("RevisionId", "")
if self.lambda_backend.get_function(function_name):
self.lambda_backend.del_policy_statement(
function_name, statement_id, revision
)
return 204, {}, "{}"
else:
return 404, {}, "{}"
def _invoke(self, request, full_url):
response_headers = {}
function_name = self.path.rsplit('/', 2)[-2]
qualifier = self._get_param('qualifier')
# URL-decode in case it's an ARN:
function_name = unquote(self.path.rsplit("/", 2)[-2])
qualifier = self._get_param("qualifier")
fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
payload = fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(len(payload))
return 202, response_headers, payload
response_header, payload = self.lambda_backend.invoke(
function_name, qualifier, self.body, self.headers, response_headers
)
if payload:
if request.headers["X-Amz-Invocation-Type"] == "Event":
status_code = 202
elif request.headers["X-Amz-Invocation-Type"] == "DryRun":
status_code = 204
else:
status_code = 200
return status_code, response_headers, payload
else:
return 404, response_headers, "{}"
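Where every invocation used to answer 202, the status now tracks X-Amz-Invocation-Type. A one-function restatement of the selection above, treating anything else as a synchronous 200:

# Illustrative only: how _invoke picks its HTTP status.
def invoke_status_code(invocation_type):
    return {"Event": 202, "DryRun": 204}.get(invocation_type, 200)

assert invoke_status_code("RequestResponse") == 200
assert invoke_status_code("Event") == 202
assert invoke_status_code("DryRun") == 204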
def _invoke_async(self, request, full_url):
response_headers = {}
function_name = self.path.rsplit('/', 3)[-3]
function_name = self.path.rsplit("/", 3)[-3]
fn = self.lambda_backend.get_function(function_name, None)
if fn:
payload = fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(len(payload))
response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload
else:
return 404, response_headers, "{}"
def _list_functions(self, request, full_url, headers):
result = {
'Functions': []
}
result = {"Functions": []}
for fn in self.lambda_backend.list_functions():
json_data = fn.get_configuration()
json_data['Version'] = '$LATEST'
result['Functions'].append(json_data)
json_data["Version"] = "$LATEST"
result["Functions"].append(json_data)
return 200, {}, json.dumps(result)
def _list_versions_by_function(self, function_name):
result = {
'Versions': []
}
result = {"Versions": []}
functions = self.lambda_backend.list_versions_by_function(function_name)
if functions:
for fn in functions:
json_data = fn.get_configuration()
result['Versions'].append(json_data)
result["Versions"].append(json_data)
return 200, {}, json.dumps(result)
def _create_function(self, request, full_url, headers):
try:
fn = self.lambda_backend.create_function(self.json_body)
except ValueError as e:
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}})
fn = self.lambda_backend.create_function(self.json_body)
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _create_event_source_mapping(self, request, full_url, headers):
fn = self.lambda_backend.create_event_source_mapping(self.json_body)
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _list_event_source_mappings(self, event_source_arn, function_name):
esms = self.lambda_backend.list_event_source_mappings(
event_source_arn, function_name
)
result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]}
return 200, {}, json.dumps(result)
def _get_event_source_mapping(self, uuid):
result = self.lambda_backend.get_event_source_mapping(uuid)
if result:
return 200, {}, json.dumps(result.get_configuration())
else:
config = fn.get_configuration()
return 201, {}, json.dumps(config)
return 404, {}, "{}"
def _update_event_source_mapping(self, uuid):
result = self.lambda_backend.update_event_source_mapping(uuid, self.json_body)
if result:
return 202, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"
def _delete_event_source_mapping(self, uuid):
esm = self.lambda_backend.delete_event_source_mapping(uuid)
if esm:
json_result = esm.get_configuration()
json_result.update({"State": "Deleting"})
return 202, {}, json.dumps(json_result)
else:
return 404, {}, "{}"
def _publish_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 2)[-2]
function_name = self.path.rsplit("/", 2)[-2]
fn = self.lambda_backend.publish_function(function_name)
if fn:
@ -188,8 +279,8 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}"
def _delete_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 1)[-1]
qualifier = self._get_param('Qualifier', None)
function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param("Qualifier", None)
if self.lambda_backend.delete_function(function_name, qualifier):
return 204, {}, ""
@ -197,20 +288,20 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}"
def _get_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 1)[-1]
qualifier = self._get_param('Qualifier', None)
function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param("Qualifier", None)
fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
code = fn.get_code()
if qualifier is None or qualifier == '$LATEST':
code['Configuration']['Version'] = '$LATEST'
if qualifier == '$LATEST':
code['Configuration']['FunctionArn'] += ':$LATEST'
if qualifier is None or qualifier == "$LATEST":
code["Configuration"]["Version"] = "$LATEST"
if qualifier == "$LATEST":
code["Configuration"]["FunctionArn"] += ":$LATEST"
return 200, {}, json.dumps(code)
else:
return 404, {}, "{}"
return 404, {"x-amzn-ErrorType": "ResourceNotFoundException"}, "{}"
def _get_aws_region(self, full_url):
region = self.region_regex.search(full_url)
@ -220,27 +311,51 @@ class LambdaResponse(BaseResponse):
return self.default_region
def _list_tags(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1])
function_arn = unquote(self.path.rsplit("/", 1)[-1])
fn = self.lambda_backend.get_function_by_arn(function_arn)
if fn:
return 200, {}, json.dumps({'Tags': fn.tags})
return 200, {}, json.dumps({"Tags": fn.tags})
else:
return 404, {}, "{}"
def _tag_resource(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1])
function_arn = unquote(self.path.rsplit("/", 1)[-1])
if self.lambda_backend.tag_resource(function_arn, self.json_body['Tags']):
if self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]):
return 200, {}, "{}"
else:
return 404, {}, "{}"
def _untag_resource(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1])
tag_keys = self.querystring['tagKeys']
function_arn = unquote(self.path.rsplit("/", 1)[-1])
tag_keys = self.querystring["tagKeys"]
if self.lambda_backend.untag_resource(function_arn, tag_keys):
return 204, {}, "{}"
else:
return 404, {}, "{}"
def _put_configuration(self, request):
function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param("Qualifier", None)
resp = self.lambda_backend.update_function_configuration(
function_name, qualifier, body=self.json_body
)
if resp:
return 200, {}, json.dumps(resp)
else:
return 404, {}, "{}"
def _put_code(self):
function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param("Qualifier", None)
resp = self.lambda_backend.update_function_code(
function_name, qualifier, body=self.json_body
)
if resp:
return 200, {}, json.dumps(resp)
else:
return 404, {}, "{}"

moto/awslambda/urls.py

@ -1,18 +1,22 @@
from __future__ import unicode_literals
from .responses import LambdaResponse
url_bases = [
"https?://lambda.(.+).amazonaws.com",
]
url_bases = ["https?://lambda.(.+).amazonaws.com"]
response = LambdaResponse()
url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$': response.versions,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$': response.policy
r"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$": response.event_source_mapping,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<resource_arn>.+)/invocations/?$": response.invoke,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/(?P<statement_id>[\w_-]+)$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,
}
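The widened character class [\w_:%-]+ is what lets a URL-encoded function ARN reach response.function. A sketch of just that regex behaviour, with the pattern simplified from the table above:

import re
from urllib.parse import quote

# Simplified from the routing table: api_version fixed, url base dropped.
pattern = re.compile(r"^/2015-03-31/functions/(?P<function_name>[\w_:%-]+)/?$")

arn = "arn:aws:lambda:us-east-1:123456789012:function:my-func"
for candidate in (arn, quote(arn, safe="")):
    match = pattern.match("/2015-03-31/functions/" + candidate)
    assert match is not None  # ':' and '%' are now accepted characters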

moto/awslambda/utils.py

@ -1,20 +1,20 @@
from collections import namedtuple
ARN = namedtuple('ARN', ['region', 'account', 'function_name', 'version'])
ARN = namedtuple("ARN", ["region", "account", "function_name", "version"])
def make_function_arn(region, account, name):
return 'arn:aws:lambda:{0}:{1}:function:{2}'.format(region, account, name)
return "arn:aws:lambda:{0}:{1}:function:{2}".format(region, account, name)
def make_function_ver_arn(region, account, name, version='1'):
def make_function_ver_arn(region, account, name, version="1"):
arn = make_function_arn(region, account, name)
return '{0}:{1}'.format(arn, version)
return "{0}:{1}".format(arn, version)
def split_function_arn(arn):
arn = arn.replace('arn:aws:lambda:', '')
arn = arn.replace("arn:aws:lambda:", "")
region, account, _, name, version = arn.split(':')
region, account, _, name, version = arn.split(":")
return ARN(region, account, name, version)
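With the two-argument replace above, the helpers round-trip cleanly:

from moto.awslambda.utils import make_function_ver_arn, split_function_arn

arn = make_function_ver_arn("us-east-1", "123456789012", "my-func", version="3")
assert arn == "arn:aws:lambda:us-east-1:123456789012:function:my-func:3"
parts = split_function_arn(arn)
assert (parts.region, parts.account, parts.function_name, parts.version) == (
    "us-east-1",
    "123456789012",
    "my-func",
    "3",
)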

moto/backends.py

@ -2,18 +2,25 @@ from __future__ import unicode_literals
from moto.acm import acm_backends
from moto.apigateway import apigateway_backends
from moto.athena import athena_backends
from moto.autoscaling import autoscaling_backends
from moto.awslambda import lambda_backends
from moto.batch import batch_backends
from moto.cloudformation import cloudformation_backends
from moto.cloudwatch import cloudwatch_backends
from moto.codecommit import codecommit_backends
from moto.codepipeline import codepipeline_backends
from moto.cognitoidentity import cognitoidentity_backends
from moto.cognitoidp import cognitoidp_backends
from moto.config import config_backends
from moto.core import moto_api_backends
from moto.datapipeline import datapipeline_backends
from moto.datasync import datasync_backends
from moto.dynamodb import dynamodb_backends
from moto.dynamodb2 import dynamodb_backends2
from moto.dynamodbstreams import dynamodbstreams_backends
from moto.ec2 import ec2_backends
from moto.ec2_instance_connect import ec2_instance_connect_backends
from moto.ecr import ecr_backends
from moto.ecs import ecs_backends
from moto.elb import elb_backends
@ -24,6 +31,8 @@ from moto.glacier import glacier_backends
from moto.glue import glue_backends
from moto.iam import iam_backends
from moto.instance_metadata import instance_metadata_backends
from moto.iot import iot_backends
from moto.iotdata import iotdata_backends
from moto.kinesis import kinesis_backends
from moto.kms import kms_backends
from moto.logs import logs_backends
@ -33,72 +42,75 @@ from moto.polly import polly_backends
from moto.rds2 import rds2_backends
from moto.redshift import redshift_backends
from moto.resourcegroups import resourcegroups_backends
from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends
from moto.route53 import route53_backends
from moto.s3 import s3_backends
from moto.ses import ses_backends
from moto.secretsmanager import secretsmanager_backends
from moto.ses import ses_backends
from moto.sns import sns_backends
from moto.sqs import sqs_backends
from moto.ssm import ssm_backends
from moto.stepfunctions import stepfunction_backends
from moto.sts import sts_backends
from moto.swf import swf_backends
from moto.xray import xray_backends
from moto.iot import iot_backends
from moto.iotdata import iotdata_backends
from moto.batch import batch_backends
from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends
from moto.config import config_backends
BACKENDS = {
'acm': acm_backends,
'apigateway': apigateway_backends,
'autoscaling': autoscaling_backends,
'batch': batch_backends,
'cloudformation': cloudformation_backends,
'cloudwatch': cloudwatch_backends,
'cognito-identity': cognitoidentity_backends,
'cognito-idp': cognitoidp_backends,
'config': config_backends,
'datapipeline': datapipeline_backends,
'dynamodb': dynamodb_backends,
'dynamodb2': dynamodb_backends2,
'dynamodbstreams': dynamodbstreams_backends,
'ec2': ec2_backends,
'ecr': ecr_backends,
'ecs': ecs_backends,
'elb': elb_backends,
'elbv2': elbv2_backends,
'events': events_backends,
'emr': emr_backends,
'glacier': glacier_backends,
'glue': glue_backends,
'iam': iam_backends,
'moto_api': moto_api_backends,
'instance_metadata': instance_metadata_backends,
'logs': logs_backends,
'kinesis': kinesis_backends,
'kms': kms_backends,
'opsworks': opsworks_backends,
'organizations': organizations_backends,
'polly': polly_backends,
'redshift': redshift_backends,
'resource-groups': resourcegroups_backends,
'rds': rds2_backends,
's3': s3_backends,
's3bucket_path': s3_backends,
'ses': ses_backends,
'secretsmanager': secretsmanager_backends,
'sns': sns_backends,
'sqs': sqs_backends,
'ssm': ssm_backends,
'sts': sts_backends,
'swf': swf_backends,
'route53': route53_backends,
'lambda': lambda_backends,
'xray': xray_backends,
'resourcegroupstaggingapi': resourcegroupstaggingapi_backends,
'iot': iot_backends,
'iot-data': iotdata_backends,
"acm": acm_backends,
"apigateway": apigateway_backends,
"athena": athena_backends,
"autoscaling": autoscaling_backends,
"batch": batch_backends,
"cloudformation": cloudformation_backends,
"cloudwatch": cloudwatch_backends,
"codecommit": codecommit_backends,
"codepipeline": codepipeline_backends,
"cognito-identity": cognitoidentity_backends,
"cognito-idp": cognitoidp_backends,
"config": config_backends,
"datapipeline": datapipeline_backends,
"datasync": datasync_backends,
"dynamodb": dynamodb_backends,
"dynamodb2": dynamodb_backends2,
"dynamodbstreams": dynamodbstreams_backends,
"ec2": ec2_backends,
"ec2_instance_connect": ec2_instance_connect_backends,
"ecr": ecr_backends,
"ecs": ecs_backends,
"elb": elb_backends,
"elbv2": elbv2_backends,
"events": events_backends,
"emr": emr_backends,
"glacier": glacier_backends,
"glue": glue_backends,
"iam": iam_backends,
"moto_api": moto_api_backends,
"instance_metadata": instance_metadata_backends,
"logs": logs_backends,
"kinesis": kinesis_backends,
"kms": kms_backends,
"opsworks": opsworks_backends,
"organizations": organizations_backends,
"polly": polly_backends,
"redshift": redshift_backends,
"resource-groups": resourcegroups_backends,
"rds": rds2_backends,
"s3": s3_backends,
"s3bucket_path": s3_backends,
"ses": ses_backends,
"secretsmanager": secretsmanager_backends,
"sns": sns_backends,
"sqs": sqs_backends,
"ssm": ssm_backends,
"stepfunctions": stepfunction_backends,
"sts": sts_backends,
"swf": swf_backends,
"route53": route53_backends,
"lambda": lambda_backends,
"xray": xray_backends,
"resourcegroupstaggingapi": resourcegroupstaggingapi_backends,
"iot": iot_backends,
"iot-data": iotdata_backends,
}
@ -106,6 +118,6 @@ def get_model(name, region_name):
for backends in BACKENDS.values():
for region, backend in backends.items():
if region == region_name:
models = getattr(backend.__class__, '__models__', {})
models = getattr(backend.__class__, "__models__", {})
if name in models:
return list(getattr(backend, models[name])())

moto/batch/__init__.py

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import batch_backends
from ..core.models import base_decorator
batch_backend = batch_backends['us-east-1']
batch_backend = batch_backends["us-east-1"]
mock_batch = base_decorator(batch_backends)

moto/batch/exceptions.py

@ -12,26 +12,29 @@ class AWSError(Exception):
self.status = status if status is not None else self.STATUS
def response(self):
return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status)
return (
json.dumps({"__type": self.code, "message": self.message}),
dict(status=self.status),
)
class InvalidRequestException(AWSError):
CODE = 'InvalidRequestException'
CODE = "InvalidRequestException"
class InvalidParameterValueException(AWSError):
CODE = 'InvalidParameterValue'
CODE = "InvalidParameterValue"
class ValidationError(AWSError):
CODE = 'ValidationError'
CODE = "ValidationError"
class InternalFailure(AWSError):
CODE = 'InternalFailure'
CODE = "InternalFailure"
STATUS = 500
class ClientException(AWSError):
CODE = 'ClientException'
CODE = "ClientException"
STATUS = 400
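response() yields a (body, status) pair ready to hand back from a response handler. A small sketch, assuming the constructor (not shown in full in this hunk) takes the message first:

import json
from moto.batch.exceptions import InternalFailure

body, headers = InternalFailure("compute environment broke").response()
assert headers == {"status": 500}
assert json.loads(body) == {
    "__type": "InternalFailure",
    "message": "compute environment broke",
}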

moto/batch/models.py: file diff suppressed because it is too large.

moto/batch/responses.py

@ -10,7 +10,7 @@ import json
class BatchResponse(BaseResponse):
def _error(self, code, message):
return json.dumps({'__type': code, 'message': message}), dict(status=400)
return json.dumps({"__type": code, "message": message}), dict(status=400)
@property
def batch_backend(self):
@ -22,9 +22,9 @@ class BatchResponse(BaseResponse):
@property
def json(self):
if self.body is None or self.body == '':
if self.body is None or self.body == "":
self._json = {}
elif not hasattr(self, '_json'):
elif not hasattr(self, "_json"):
try:
self._json = json.loads(self.body)
except ValueError:
@ -39,153 +39,146 @@ class BatchResponse(BaseResponse):
def _get_action(self):
# Return element after the /v1/*
return urlsplit(self.uri).path.lstrip('/').split('/')[1]
return urlsplit(self.uri).path.lstrip("/").split("/")[1]
# CreateComputeEnvironment
def createcomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironmentName')
compute_resource = self._get_param('computeResources')
service_role = self._get_param('serviceRole')
state = self._get_param('state')
_type = self._get_param('type')
compute_env_name = self._get_param("computeEnvironmentName")
compute_resource = self._get_param("computeResources")
service_role = self._get_param("serviceRole")
state = self._get_param("state")
_type = self._get_param("type")
try:
name, arn = self.batch_backend.create_compute_environment(
compute_environment_name=compute_env_name,
_type=_type, state=state,
_type=_type,
state=state,
compute_resources=compute_resource,
service_role=service_role
service_role=service_role,
)
except AWSError as err:
return err.response()
result = {
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
return json.dumps(result)
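A hedged sketch of the simplest path through this handler, an UNMANAGED environment backed by a mock IAM service role; all names are invented:

import boto3
from moto import mock_batch, mock_iam

@mock_iam
@mock_batch
def test_create_unmanaged_compute_environment():
    iam = boto3.client("iam", region_name="us-east-1")
    service_role = iam.create_role(
        RoleName="batch-service", AssumeRolePolicyDocument="{}"
    )["Role"]["Arn"]
    batch = boto3.client("batch", region_name="us-east-1")
    resp = batch.create_compute_environment(
        computeEnvironmentName="unmanaged-env",
        type="UNMANAGED",  # MANAGED would additionally need computeResources
        state="ENABLED",
        serviceRole=service_role,
    )
    assert resp["computeEnvironmentName"] == "unmanaged-env"
    envs = batch.describe_compute_environments()["computeEnvironments"]
    assert len(envs) == 1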
# DescribeComputeEnvironments
def describecomputeenvironments(self):
compute_environments = self._get_param('computeEnvironments')
max_results = self._get_param('maxResults') # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored
compute_environments = self._get_param("computeEnvironments")
max_results = self._get_param("maxResults") # Ignored, should be int
next_token = self._get_param("nextToken") # Ignored
envs = self.batch_backend.describe_compute_environments(compute_environments, max_results=max_results, next_token=next_token)
envs = self.batch_backend.describe_compute_environments(
compute_environments, max_results=max_results, next_token=next_token
)
result = {'computeEnvironments': envs}
result = {"computeEnvironments": envs}
return json.dumps(result)
# DeleteComputeEnvironment
def deletecomputeenvironment(self):
compute_environment = self._get_param('computeEnvironment')
compute_environment = self._get_param("computeEnvironment")
try:
self.batch_backend.delete_compute_environment(compute_environment)
except AWSError as err:
return err.response()
return ''
return ""
# UpdateComputeEnvironment
def updatecomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironment')
compute_resource = self._get_param('computeResources')
service_role = self._get_param('serviceRole')
state = self._get_param('state')
compute_env_name = self._get_param("computeEnvironment")
compute_resource = self._get_param("computeResources")
service_role = self._get_param("serviceRole")
state = self._get_param("state")
try:
name, arn = self.batch_backend.update_compute_environment(
compute_environment_name=compute_env_name,
compute_resources=compute_resource,
service_role=service_role,
state=state
state=state,
)
except AWSError as err:
return err.response()
result = {
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
return json.dumps(result)
# CreateJobQueue
def createjobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder')
queue_name = self._get_param('jobQueueName')
priority = self._get_param('priority')
state = self._get_param('state')
compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueueName")
priority = self._get_param("priority")
state = self._get_param("state")
try:
name, arn = self.batch_backend.create_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order
compute_env_order=compute_env_order,
)
except AWSError as err:
return err.response()
result = {
'jobQueueArn': arn,
'jobQueueName': name
}
result = {"jobQueueArn": arn, "jobQueueName": name}
return json.dumps(result)
# DescribeJobQueues
def describejobqueues(self):
job_queues = self._get_param('jobQueues')
max_results = self._get_param('maxResults') # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored
job_queues = self._get_param("jobQueues")
max_results = self._get_param("maxResults") # Ignored, should be int
next_token = self._get_param("nextToken") # Ignored
queues = self.batch_backend.describe_job_queues(job_queues, max_results=max_results, next_token=next_token)
queues = self.batch_backend.describe_job_queues(
job_queues, max_results=max_results, next_token=next_token
)
result = {'jobQueues': queues}
result = {"jobQueues": queues}
return json.dumps(result)
# UpdateJobQueue
def updatejobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder')
queue_name = self._get_param('jobQueue')
priority = self._get_param('priority')
state = self._get_param('state')
compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueue")
priority = self._get_param("priority")
state = self._get_param("state")
try:
name, arn = self.batch_backend.update_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order
compute_env_order=compute_env_order,
)
except AWSError as err:
return err.response()
result = {
'jobQueueArn': arn,
'jobQueueName': name
}
result = {"jobQueueArn": arn, "jobQueueName": name}
return json.dumps(result)
# DeleteJobQueue
def deletejobqueue(self):
queue_name = self._get_param('jobQueue')
queue_name = self._get_param("jobQueue")
self.batch_backend.delete_job_queue(queue_name)
return ''
return ""
# RegisterJobDefinition
def registerjobdefinition(self):
container_properties = self._get_param('containerProperties')
def_name = self._get_param('jobDefinitionName')
parameters = self._get_param('parameters')
retry_strategy = self._get_param('retryStrategy')
_type = self._get_param('type')
container_properties = self._get_param("containerProperties")
def_name = self._get_param("jobDefinitionName")
parameters = self._get_param("parameters")
retry_strategy = self._get_param("retryStrategy")
_type = self._get_param("type")
try:
name, arn, revision = self.batch_backend.register_job_definition(
@ -193,104 +186,113 @@ class BatchResponse(BaseResponse):
parameters=parameters,
_type=_type,
retry_strategy=retry_strategy,
container_properties=container_properties
container_properties=container_properties,
)
except AWSError as err:
return err.response()
result = {
'jobDefinitionArn': arn,
'jobDefinitionName': name,
'revision': revision
"jobDefinitionArn": arn,
"jobDefinitionName": name,
"revision": revision,
}
return json.dumps(result)
# DeregisterJobDefinition
def deregisterjobdefinition(self):
queue_name = self._get_param('jobDefinition')
queue_name = self._get_param("jobDefinition")
self.batch_backend.deregister_job_definition(queue_name)
return ''
return ""
# DescribeJobDefinitions
def describejobdefinitions(self):
job_def_name = self._get_param('jobDefinitionName')
job_def_list = self._get_param('jobDefinitions')
max_results = self._get_param('maxResults')
next_token = self._get_param('nextToken')
status = self._get_param('status')
job_def_name = self._get_param("jobDefinitionName")
job_def_list = self._get_param("jobDefinitions")
max_results = self._get_param("maxResults")
next_token = self._get_param("nextToken")
status = self._get_param("status")
job_defs = self.batch_backend.describe_job_definitions(job_def_name, job_def_list, status, max_results, next_token)
job_defs = self.batch_backend.describe_job_definitions(
job_def_name, job_def_list, status, max_results, next_token
)
result = {'jobDefinitions': [job.describe() for job in job_defs]}
result = {"jobDefinitions": [job.describe() for job in job_defs]}
return json.dumps(result)
# SubmitJob
def submitjob(self):
container_overrides = self._get_param('containerOverrides')
depends_on = self._get_param('dependsOn')
job_def = self._get_param('jobDefinition')
job_name = self._get_param('jobName')
job_queue = self._get_param('jobQueue')
parameters = self._get_param('parameters')
retries = self._get_param('retryStrategy')
container_overrides = self._get_param("containerOverrides")
depends_on = self._get_param("dependsOn")
job_def = self._get_param("jobDefinition")
job_name = self._get_param("jobName")
job_queue = self._get_param("jobQueue")
parameters = self._get_param("parameters")
retries = self._get_param("retryStrategy")
try:
name, job_id = self.batch_backend.submit_job(
job_name, job_def, job_queue,
job_name,
job_def,
job_queue,
parameters=parameters,
retries=retries,
depends_on=depends_on,
container_overrides=container_overrides
container_overrides=container_overrides,
)
except AWSError as err:
return err.response()
result = {
'jobId': job_id,
'jobName': name,
}
result = {"jobId": job_id, "jobName": name}
return json.dumps(result)
# DescribeJobs
def describejobs(self):
jobs = self._get_param('jobs')
jobs = self._get_param("jobs")
try:
return json.dumps({'jobs': self.batch_backend.describe_jobs(jobs)})
return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
except AWSError as err:
return err.response()
# ListJobs
def listjobs(self):
job_queue = self._get_param('jobQueue')
job_status = self._get_param('jobStatus')
max_results = self._get_param('maxResults')
next_token = self._get_param('nextToken')
job_queue = self._get_param("jobQueue")
job_status = self._get_param("jobStatus")
max_results = self._get_param("maxResults")
next_token = self._get_param("nextToken")
try:
jobs = self.batch_backend.list_jobs(job_queue, job_status, max_results, next_token)
jobs = self.batch_backend.list_jobs(
job_queue, job_status, max_results, next_token
)
except AWSError as err:
return err.response()
result = {'jobSummaryList': [{'jobId': job.job_id, 'jobName': job.job_name} for job in jobs]}
result = {
"jobSummaryList": [
{"jobId": job.job_id, "jobName": job.job_name} for job in jobs
]
}
return json.dumps(result)
# TerminateJob
def terminatejob(self):
job_id = self._get_param('jobId')
reason = self._get_param('reason')
job_id = self._get_param("jobId")
reason = self._get_param("reason")
try:
self.batch_backend.terminate_job(job_id, reason)
except AWSError as err:
return err.response()
return ''
return ""
# CancelJob
def canceljob(self): # There's some AWS semantics on the differences but for us they're identical ;-)
def canceljob(
self,
): # There's some AWS semantics on the differences but for us they're identical ;-)
return self.terminatejob()

View File

@ -1,25 +1,23 @@
from __future__ import unicode_literals
from .responses import BatchResponse
url_bases = [
"https?://batch.(.+).amazonaws.com",
]
url_bases = ["https?://batch.(.+).amazonaws.com"]
url_paths = {
'{0}/v1/createcomputeenvironment$': BatchResponse.dispatch,
'{0}/v1/describecomputeenvironments$': BatchResponse.dispatch,
'{0}/v1/deletecomputeenvironment': BatchResponse.dispatch,
'{0}/v1/updatecomputeenvironment': BatchResponse.dispatch,
'{0}/v1/createjobqueue': BatchResponse.dispatch,
'{0}/v1/describejobqueues': BatchResponse.dispatch,
'{0}/v1/updatejobqueue': BatchResponse.dispatch,
'{0}/v1/deletejobqueue': BatchResponse.dispatch,
'{0}/v1/registerjobdefinition': BatchResponse.dispatch,
'{0}/v1/deregisterjobdefinition': BatchResponse.dispatch,
'{0}/v1/describejobdefinitions': BatchResponse.dispatch,
'{0}/v1/submitjob': BatchResponse.dispatch,
'{0}/v1/describejobs': BatchResponse.dispatch,
'{0}/v1/listjobs': BatchResponse.dispatch,
'{0}/v1/terminatejob': BatchResponse.dispatch,
'{0}/v1/canceljob': BatchResponse.dispatch,
"{0}/v1/createcomputeenvironment$": BatchResponse.dispatch,
"{0}/v1/describecomputeenvironments$": BatchResponse.dispatch,
"{0}/v1/deletecomputeenvironment": BatchResponse.dispatch,
"{0}/v1/updatecomputeenvironment": BatchResponse.dispatch,
"{0}/v1/createjobqueue": BatchResponse.dispatch,
"{0}/v1/describejobqueues": BatchResponse.dispatch,
"{0}/v1/updatejobqueue": BatchResponse.dispatch,
"{0}/v1/deletejobqueue": BatchResponse.dispatch,
"{0}/v1/registerjobdefinition": BatchResponse.dispatch,
"{0}/v1/deregisterjobdefinition": BatchResponse.dispatch,
"{0}/v1/describejobdefinitions": BatchResponse.dispatch,
"{0}/v1/submitjob": BatchResponse.dispatch,
"{0}/v1/describejobs": BatchResponse.dispatch,
"{0}/v1/listjobs": BatchResponse.dispatch,
"{0}/v1/terminatejob": BatchResponse.dispatch,
"{0}/v1/canceljob": BatchResponse.dispatch,
}
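Each key above is a format string: {0} is substituted with every entry of url_bases when moto wires up its routes, so the effective pattern for CreateComputeEnvironment becomes "https?://batch.(.+).amazonaws.com/v1/createcomputeenvironment$". A rough sketch of that expansion (the loop is an assumption about moto's URL setup, not code from this module):

effective_patterns = {
    path.format(base): handler
    for base in url_bases
    for path, handler in url_paths.items()
}
# Incoming request URLs are matched against these patterns to pick
# BatchResponse.dispatch as the handler.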

View File

@ -2,7 +2,9 @@ from __future__ import unicode_literals
def make_arn_for_compute_env(account_id, name, region_name):
return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(region_name, account_id, name)
return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(
region_name, account_id, name
)
def make_arn_for_job_queue(account_id, name, region_name):
@ -10,7 +12,9 @@ def make_arn_for_job_queue(account_id, name, region_name):
def make_arn_for_task_def(account_id, name, revision, region_name):
return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(region_name, account_id, name, revision)
return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(
region_name, account_id, name, revision
)
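For illustration, the two helpers above produce ARNs of the following shape (account id and names are made up):

make_arn_for_compute_env("123456789012", "my-env", "us-east-1")
# -> "arn:aws:batch:us-east-1:123456789012:compute-environment/my-env"

make_arn_for_task_def("123456789012", "my-def", 1, "us-east-1")
# -> "arn:aws:batch:us-east-1:123456789012:job-definition/my-def:1"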
def lowercase_first_key(some_dict):

View File

@ -2,7 +2,6 @@ from __future__ import unicode_literals
from .models import cloudformation_backends
from ..core.models import base_decorator, deprecated_base_decorator
cloudformation_backend = cloudformation_backends['us-east-1']
cloudformation_backend = cloudformation_backends["us-east-1"]
mock_cloudformation = base_decorator(cloudformation_backends)
mock_cloudformation_deprecated = deprecated_base_decorator(
cloudformation_backends)
mock_cloudformation_deprecated = deprecated_base_decorator(cloudformation_backends)

View File

@ -4,26 +4,23 @@ from jinja2 import Template
class UnformattedGetAttTemplateException(Exception):
description = 'Template error: resource {0} does not support attribute type {1} in Fn::GetAtt'
description = (
"Template error: resource {0} does not support attribute type {1} in Fn::GetAtt"
)
status_code = 400
class ValidationError(BadRequest):
def __init__(self, name_or_id, message=None):
if message is None:
message = "Stack with id {0} does not exist".format(name_or_id)
template = Template(ERROR_RESPONSE)
super(ValidationError, self).__init__()
self.description = template.render(
code="ValidationError",
message=message,
)
self.description = template.render(code="ValidationError", message=message)
class MissingParameterError(BadRequest):
def __init__(self, parameter_name):
template = Template(ERROR_RESPONSE)
super(MissingParameterError, self).__init__()
@ -40,8 +37,8 @@ class ExportNotFound(BadRequest):
template = Template(ERROR_RESPONSE)
super(ExportNotFound, self).__init__()
self.description = template.render(
code='ExportNotFound',
message="No export named {0} found.".format(export_name)
code="ExportNotFound",
message="No export named {0} found.".format(export_name),
)

View File

@ -4,9 +4,11 @@ import json
import yaml
import uuid
import boto.cloudformation
from boto3 import Session
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .parsing import ResourceMap, OutputMap
from .utils import (
@ -21,11 +23,19 @@ from .exceptions import ValidationError
class FakeStackSet(BaseModel):
def __init__(self, stackset_id, name, template, region='us-east-1',
status='ACTIVE', description=None, parameters=None, tags=None,
admin_role='AWSCloudFormationStackSetAdministrationRole',
execution_role='AWSCloudFormationStackSetExecutionRole'):
def __init__(
self,
stackset_id,
name,
template,
region="us-east-1",
status="ACTIVE",
description=None,
parameters=None,
tags=None,
admin_role="AWSCloudFormationStackSetAdministrationRole",
execution_role="AWSCloudFormationStackSetExecutionRole",
):
self.id = stackset_id
self.arn = generate_stackset_arn(stackset_id, region)
self.name = name
@ -42,12 +52,14 @@ class FakeStackSet(BaseModel):
def _create_operation(self, operation_id, action, status, accounts=[], regions=[]):
operation = {
'OperationId': str(operation_id),
'Action': action,
'Status': status,
'CreationTimestamp': datetime.now(),
'EndTimestamp': datetime.now() + timedelta(minutes=2),
'Instances': [{account: region} for account in accounts for region in regions],
"OperationId": str(operation_id),
"Action": action,
"Status": status,
"CreationTimestamp": datetime.now(),
"EndTimestamp": datetime.now() + timedelta(minutes=2),
"Instances": [
{account: region} for account in accounts for region in regions
],
}
self.operations += [operation]
@ -55,20 +67,30 @@ class FakeStackSet(BaseModel):
def get_operation(self, operation_id):
for operation in self.operations:
if operation_id == operation['OperationId']:
if operation_id == operation["OperationId"]:
return operation
raise ValidationError(operation_id)
def update_operation(self, operation_id, status):
operation = self.get_operation(operation_id)
operation['Status'] = status
operation["Status"] = status
return operation_id
def delete(self):
self.status = 'DELETED'
self.status = "DELETED"
def update(self, template, description, parameters, tags, admin_role,
execution_role, accounts, regions, operation_id=None):
def update(
self,
template,
description,
parameters,
tags,
admin_role,
execution_role,
accounts,
regions,
operation_id=None,
):
if not operation_id:
operation_id = uuid.uuid4()
@ -82,9 +104,13 @@ class FakeStackSet(BaseModel):
if accounts and regions:
self.update_instances(accounts, regions, self.parameters)
operation = self._create_operation(operation_id=operation_id,
action='UPDATE', status='SUCCEEDED', accounts=accounts,
regions=regions)
operation = self._create_operation(
operation_id=operation_id,
action="UPDATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation
def create_stack_instances(self, accounts, regions, parameters, operation_id=None):
@ -94,8 +120,13 @@ class FakeStackSet(BaseModel):
parameters = self.parameters
self.instances.create_instances(accounts, regions, parameters, operation_id)
self._create_operation(operation_id=operation_id, action='CREATE',
status='SUCCEEDED', accounts=accounts, regions=regions)
self._create_operation(
operation_id=operation_id,
action="CREATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
def delete_stack_instances(self, accounts, regions, operation_id=None):
if not operation_id:
@ -103,8 +134,13 @@ class FakeStackSet(BaseModel):
self.instances.delete(accounts, regions)
operation = self._create_operation(operation_id=operation_id, action='DELETE',
status='SUCCEEDED', accounts=accounts, regions=regions)
operation = self._create_operation(
operation_id=operation_id,
action="DELETE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation
def update_instances(self, accounts, regions, parameters, operation_id=None):
@ -112,9 +148,13 @@ class FakeStackSet(BaseModel):
operation_id = uuid.uuid4()
self.instances.update(accounts, regions, parameters)
operation = self._create_operation(operation_id=operation_id,
action='UPDATE', status='SUCCEEDED', accounts=accounts,
regions=regions)
operation = self._create_operation(
operation_id=operation_id,
action="UPDATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation
@ -131,12 +171,12 @@ class FakeStackInstances(BaseModel):
for region in regions:
for account in accounts:
instance = {
'StackId': generate_stack_id(self.stack_name, region, account),
'StackSetId': self.stackset_id,
'Region': region,
'Account': account,
'Status': "CURRENT",
'ParameterOverrides': parameters if parameters else [],
"StackId": generate_stack_id(self.stack_name, region, account),
"StackSetId": self.stackset_id,
"Region": region,
"Account": account,
"Status": "CURRENT",
"ParameterOverrides": parameters if parameters else [],
}
new_instances.append(instance)
self.stack_instances += new_instances
@ -147,24 +187,35 @@ class FakeStackInstances(BaseModel):
for region in regions:
instance = self.get_instance(account, region)
if parameters:
instance['ParameterOverrides'] = parameters
instance["ParameterOverrides"] = parameters
else:
instance['ParameterOverrides'] = []
instance["ParameterOverrides"] = []
def delete(self, accounts, regions):
for i, instance in enumerate(self.stack_instances):
if instance['Region'] in regions and instance['Account'] in accounts:
if instance["Region"] in regions and instance["Account"] in accounts:
self.stack_instances.pop(i)
def get_instance(self, account, region):
for i, instance in enumerate(self.stack_instances):
if instance['Region'] == region and instance['Account'] == account:
if instance["Region"] == region and instance["Account"] == account:
return self.stack_instances[i]
class FakeStack(BaseModel):
def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None, create_change_set=False):
def __init__(
self,
stack_id,
name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
cross_stack_resources=None,
create_change_set=False,
):
self.stack_id = stack_id
self.name = name
self.template = template
@ -176,22 +227,32 @@ class FakeStack(BaseModel):
self.tags = tags if tags else {}
self.events = []
if create_change_set:
self._add_stack_event("REVIEW_IN_PROGRESS",
resource_status_reason="User Initiated")
self._add_stack_event(
"REVIEW_IN_PROGRESS", resource_status_reason="User Initiated"
)
else:
self._add_stack_event("CREATE_IN_PROGRESS",
resource_status_reason="User Initiated")
self._add_stack_event(
"CREATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.description = self.template_dict.get('Description')
self.description = self.template_dict.get("Description")
self.cross_stack_resources = cross_stack_resources or {}
self.resource_map = self._create_resource_map()
self.output_map = self._create_output_map()
self._add_stack_event("CREATE_COMPLETE")
self.status = 'CREATE_COMPLETE'
self.status = "CREATE_COMPLETE"
self.creation_time = datetime.utcnow()
def _create_resource_map(self):
resource_map = ResourceMap(
self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict, self.cross_stack_resources)
self.stack_id,
self.name,
self.parameters,
self.tags,
self.region_name,
self.template_dict,
self.cross_stack_resources,
)
resource_map.create()
return resource_map
@ -200,34 +261,50 @@ class FakeStack(BaseModel):
output_map.create()
return output_map
def _add_stack_event(self, resource_status, resource_status_reason=None, resource_properties=None):
self.events.append(FakeEvent(
stack_id=self.stack_id,
stack_name=self.name,
logical_resource_id=self.name,
physical_resource_id=self.stack_id,
resource_type="AWS::CloudFormation::Stack",
resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
))
@property
def creation_time_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.creation_time)
def _add_resource_event(self, logical_resource_id, resource_status, resource_status_reason=None, resource_properties=None):
def _add_stack_event(
self, resource_status, resource_status_reason=None, resource_properties=None
):
self.events.append(
FakeEvent(
stack_id=self.stack_id,
stack_name=self.name,
logical_resource_id=self.name,
physical_resource_id=self.stack_id,
resource_type="AWS::CloudFormation::Stack",
resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
)
)
def _add_resource_event(
self,
logical_resource_id,
resource_status,
resource_status_reason=None,
resource_properties=None,
):
# not used yet... feel free to help yourself
resource = self.resource_map[logical_resource_id]
self.events.append(FakeEvent(
stack_id=self.stack_id,
stack_name=self.name,
logical_resource_id=logical_resource_id,
physical_resource_id=resource.physical_resource_id,
resource_type=resource.type,
resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
))
self.events.append(
FakeEvent(
stack_id=self.stack_id,
stack_name=self.name,
logical_resource_id=logical_resource_id,
physical_resource_id=resource.physical_resource_id,
resource_type=resource.type,
resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
)
)
def _parse_template(self):
yaml.add_multi_constructor('', yaml_tag_constructor)
yaml.add_multi_constructor("", yaml_tag_constructor)
try:
self.template_dict = yaml.load(self.template, Loader=yaml.Loader)
except yaml.parser.ParserError:
@ -250,7 +327,9 @@ class FakeStack(BaseModel):
return self.output_map.exports
def update(self, template, role_arn=None, parameters=None, tags=None):
self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated")
self._add_stack_event(
"UPDATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.template = template
self._parse_template()
self.resource_map.update(self.template_dict, parameters)
@ -264,15 +343,15 @@ class FakeStack(BaseModel):
# TODO: update tags in the resource map
def delete(self):
self._add_stack_event("DELETE_IN_PROGRESS",
resource_status_reason="User Initiated")
self._add_stack_event(
"DELETE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.resource_map.delete()
self._add_stack_event("DELETE_COMPLETE")
self.status = "DELETE_COMPLETE"
class FakeChange(BaseModel):
def __init__(self, action, logical_resource_id, resource_type):
self.action = action
self.logical_resource_id = logical_resource_id
@ -280,8 +359,21 @@ class FakeChange(BaseModel):
class FakeChangeSet(FakeStack):
def __init__(self, stack_id, stack_name, stack_template, change_set_id, change_set_name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None):
def __init__(
self,
stack_id,
stack_name,
stack_template,
change_set_id,
change_set_name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
cross_stack_resources=None,
):
super(FakeChangeSet, self).__init__(
stack_id,
stack_name,
@ -306,17 +398,28 @@ class FakeChangeSet(FakeStack):
resources_by_action = self.resource_map.diff(self.template_dict, parameters)
for action, resources in resources_by_action.items():
for resource_name, resource in resources.items():
changes.append(FakeChange(
action=action,
logical_resource_id=resource_name,
resource_type=resource['ResourceType'],
))
changes.append(
FakeChange(
action=action,
logical_resource_id=resource_name,
resource_type=resource["ResourceType"],
)
)
return changes
class FakeEvent(BaseModel):
def __init__(self, stack_id, stack_name, logical_resource_id, physical_resource_id, resource_type, resource_status, resource_status_reason=None, resource_properties=None):
def __init__(
self,
stack_id,
stack_name,
logical_resource_id,
physical_resource_id,
resource_type,
resource_status,
resource_status_reason=None,
resource_properties=None,
):
self.stack_id = stack_id
self.stack_name = stack_name
self.logical_resource_id = logical_resource_id
@ -330,7 +433,6 @@ class FakeEvent(BaseModel):
class CloudFormationBackend(BaseBackend):
def __init__(self):
self.stacks = OrderedDict()
self.stacksets = OrderedDict()
@ -338,7 +440,17 @@ class CloudFormationBackend(BaseBackend):
self.exports = OrderedDict()
self.change_sets = OrderedDict()
def create_stack_set(self, name, template, parameters, tags=None, description=None, region='us-east-1', admin_role=None, execution_role=None):
def create_stack_set(
self,
name,
template,
parameters,
tags=None,
description=None,
region="us-east-1",
admin_role=None,
execution_role=None,
):
stackset_id = generate_stackset_id(name)
new_stackset = FakeStackSet(
stackset_id=stackset_id,
@ -366,7 +478,9 @@ class CloudFormationBackend(BaseBackend):
if self.stacksets[stackset].name == name:
self.stacksets[stackset].delete()
def create_stack_instances(self, stackset_name, accounts, regions, parameters, operation_id=None):
def create_stack_instances(
self, stackset_name, accounts, regions, parameters, operation_id=None
):
stackset = self.get_stack_set(stackset_name)
stackset.create_stack_instances(
@ -377,9 +491,19 @@ class CloudFormationBackend(BaseBackend):
)
return stackset
def update_stack_set(self, stackset_name, template=None, description=None,
parameters=None, tags=None, admin_role=None, execution_role=None,
accounts=None, regions=None, operation_id=None):
def update_stack_set(
self,
stackset_name,
template=None,
description=None,
parameters=None,
tags=None,
admin_role=None,
execution_role=None,
accounts=None,
regions=None,
operation_id=None,
):
stackset = self.get_stack_set(stackset_name)
update = stackset.update(
template=template,
@ -390,16 +514,28 @@ class CloudFormationBackend(BaseBackend):
execution_role=execution_role,
accounts=accounts,
regions=regions,
operation_id=operation_id
operation_id=operation_id,
)
return update
def delete_stack_instances(self, stackset_name, accounts, regions, operation_id=None):
def delete_stack_instances(
self, stackset_name, accounts, regions, operation_id=None
):
stackset = self.get_stack_set(stackset_name)
stackset.delete_stack_instances(accounts, regions, operation_id)
return stackset
def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, create_change_set=False):
def create_stack(
self,
name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
create_change_set=False,
):
stack_id = generate_stack_id(name)
new_stack = FakeStack(
stack_id=stack_id,
@ -419,10 +555,21 @@ class CloudFormationBackend(BaseBackend):
self.exports[export.name] = export
return new_stack
def create_change_set(self, stack_name, change_set_name, template, parameters, region_name, change_set_type, notification_arns=None, tags=None, role_arn=None):
def create_change_set(
self,
stack_name,
change_set_name,
template,
parameters,
region_name,
change_set_type,
notification_arns=None,
tags=None,
role_arn=None,
):
stack_id = None
stack_template = None
if change_set_type == 'UPDATE':
if change_set_type == "UPDATE":
stacks = self.stacks.values()
stack = None
for s in stacks:
@ -449,7 +596,7 @@ class CloudFormationBackend(BaseBackend):
notification_arns=notification_arns,
tags=tags,
role_arn=role_arn,
cross_stack_resources=self.exports
cross_stack_resources=self.exports,
)
self.change_sets[change_set_id] = new_change_set
self.stacks[stack_id] = new_change_set
@ -488,11 +635,11 @@ class CloudFormationBackend(BaseBackend):
stack = self.change_sets[cs]
if stack is None:
raise ValidationError(stack_name)
if stack.events[-1].resource_status == 'REVIEW_IN_PROGRESS':
stack._add_stack_event('CREATE_COMPLETE')
if stack.events[-1].resource_status == "REVIEW_IN_PROGRESS":
stack._add_stack_event("CREATE_COMPLETE")
else:
stack._add_stack_event('UPDATE_IN_PROGRESS')
stack._add_stack_event('UPDATE_COMPLETE')
stack._add_stack_event("UPDATE_IN_PROGRESS")
stack._add_stack_event("UPDATE_COMPLETE")
return True
def describe_stacks(self, name_or_stack_id):
@ -514,9 +661,7 @@ class CloudFormationBackend(BaseBackend):
return self.change_sets.values()
def list_stacks(self):
return [
v for v in self.stacks.values()
] + [
return [v for v in self.stacks.values()] + [
v for v in self.deleted_stacks.values()
]
@ -538,6 +683,8 @@ class CloudFormationBackend(BaseBackend):
def list_stack_resources(self, stack_name_or_id):
stack = self.get_stack(stack_name_or_id)
if stack is None:
return None
return stack.stack_resources
def delete_stack(self, name_or_stack_id):
@ -558,10 +705,10 @@ class CloudFormationBackend(BaseBackend):
all_exports = list(self.exports.values())
if token is None:
exports = all_exports[0:100]
next_token = '100' if len(all_exports) > 100 else None
next_token = "100" if len(all_exports) > 100 else None
else:
token = int(token)
exports = all_exports[token:token + 100]
exports = all_exports[token : token + 100]
next_token = str(token + 100) if len(all_exports) > token + 100 else None
return exports, next_token
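The pager above hands out exports in fixed pages of 100 and encodes the next offset as a string token. A hypothetical walk-through with 150 exports (the backend variable and call shape are illustrative, assuming the method takes the token as used here):

exports, token = backend.list_exports(token=None)    # all_exports[0:100], token == "100"
exports, token = backend.list_exports(token="100")   # all_exports[100:150], token is None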
@ -572,9 +719,20 @@ class CloudFormationBackend(BaseBackend):
new_stack_export_names = [x.name for x in stack.exports]
export_names = self.exports.keys()
if not set(export_names).isdisjoint(new_stack_export_names):
raise ValidationError(stack.stack_id, message='Export names must be unique across a given region')
raise ValidationError(
stack.stack_id,
message="Export names must be unique across a given region",
)
cloudformation_backends = {}
for region in boto.cloudformation.regions():
cloudformation_backends[region.name] = CloudFormationBackend()
for region in Session().get_available_regions("cloudformation"):
cloudformation_backends[region] = CloudFormationBackend()
for region in Session().get_available_regions(
"cloudformation", partition_name="aws-us-gov"
):
cloudformation_backends[region] = CloudFormationBackend()
for region in Session().get_available_regions(
"cloudformation", partition_name="aws-cn"
):
cloudformation_backends[region] = CloudFormationBackend()
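After these loops a backend exists for every region boto3 reports in the standard, GovCloud, and China partitions, e.g. (region names shown for illustration):

cloudformation_backends["us-east-1"]      # standard aws partition
cloudformation_backends["us-gov-west-1"]  # aws-us-gov partition
cloudformation_backends["cn-north-1"]     # aws-cn partition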

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals
import collections
import functools
import json
import logging
import copy
import warnings
@ -11,12 +11,14 @@ from moto.awslambda import models as lambda_models
from moto.batch import models as batch_models
from moto.cloudwatch import models as cloudwatch_models
from moto.cognitoidentity import models as cognitoidentity_models
from moto.compat import collections_abc
from moto.datapipeline import models as datapipeline_models
from moto.dynamodb2 import models as dynamodb2_models
from moto.ec2 import models as ec2_models
from moto.ecs import models as ecs_models
from moto.elb import models as elb_models
from moto.elbv2 import models as elbv2_models
from moto.events import models as events_models
from moto.iam import models as iam_models
from moto.kinesis import models as kinesis_models
from moto.kms import models as kms_models
@ -24,11 +26,18 @@ from moto.rds import models as rds_models
from moto.rds2 import models as rds2_models
from moto.redshift import models as redshift_models
from moto.route53 import models as route53_models
from moto.s3 import models as s3_models
from moto.s3 import models as s3_models, s3_backend
from moto.s3.utils import bucket_and_name_from_url
from moto.sns import models as sns_models
from moto.sqs import models as sqs_models
from moto.core import ACCOUNT_ID
from .utils import random_suffix
from .exceptions import ExportNotFound, MissingParameterError, UnformattedGetAttTemplateException, ValidationError
from .exceptions import (
ExportNotFound,
MissingParameterError,
UnformattedGetAttTemplateException,
ValidationError,
)
from boto.cloudformation.stack import Output
MODEL_MAP = {
@ -86,6 +95,7 @@ MODEL_MAP = {
"AWS::SNS::Topic": sns_models.Topic,
"AWS::S3::Bucket": s3_models.FakeBucket,
"AWS::SQS::Queue": sqs_models.Queue,
"AWS::Events::Rule": events_models.Rule,
}
# http://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-name.html
@ -100,7 +110,7 @@ NAME_TYPE_MAP = {
"AWS::RDS::DBInstance": "DBInstanceIdentifier",
"AWS::S3::Bucket": "BucketName",
"AWS::SNS::Topic": "TopicName",
"AWS::SQS::Queue": "QueueName"
"AWS::SQS::Queue": "QueueName",
}
# Just ignore these models types for now
@ -109,13 +119,12 @@ NULL_MODELS = [
"AWS::CloudFormation::WaitConditionHandle",
]
DEFAULT_REGION = 'us-east-1'
DEFAULT_REGION = "us-east-1"
logger = logging.getLogger("moto")
class LazyDict(dict):
def __getitem__(self, key):
val = dict.__getitem__(self, key)
if callable(val):
@ -132,10 +141,10 @@ def clean_json(resource_json, resources_map):
Eventually, this is where we would add things like function parsing (fn::)
"""
if isinstance(resource_json, dict):
if 'Ref' in resource_json:
if "Ref" in resource_json:
# Parse resource reference
resource = resources_map[resource_json['Ref']]
if hasattr(resource, 'physical_resource_id'):
resource = resources_map[resource_json["Ref"]]
if hasattr(resource, "physical_resource_id"):
return resource.physical_resource_id
else:
return resource
@ -145,77 +154,98 @@ def clean_json(resource_json, resources_map):
map_path = resource_json["Fn::FindInMap"][1:]
result = resources_map[map_name]
for path in map_path:
result = result[clean_json(path, resources_map)]
if "Fn::Transform" in result:
result = resources_map[clean_json(path, resources_map)]
else:
result = result[clean_json(path, resources_map)]
return result
if 'Fn::GetAtt' in resource_json:
resource = resources_map.get(resource_json['Fn::GetAtt'][0])
if "Fn::GetAtt" in resource_json:
resource = resources_map.get(resource_json["Fn::GetAtt"][0])
if resource is None:
return resource_json
try:
return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1])
return resource.get_cfn_attribute(resource_json["Fn::GetAtt"][1])
except NotImplementedError as n:
logger.warning(str(n).format(
resource_json['Fn::GetAtt'][0]))
logger.warning(str(n).format(resource_json["Fn::GetAtt"][0]))
except UnformattedGetAttTemplateException:
raise ValidationError(
'Bad Request',
"Bad Request",
UnformattedGetAttTemplateException.description.format(
resource_json['Fn::GetAtt'][0], resource_json['Fn::GetAtt'][1]))
resource_json["Fn::GetAtt"][0], resource_json["Fn::GetAtt"][1]
),
)
if 'Fn::If' in resource_json:
condition_name, true_value, false_value = resource_json['Fn::If']
if "Fn::If" in resource_json:
condition_name, true_value, false_value = resource_json["Fn::If"]
if resources_map.lazy_condition_map[condition_name]:
return clean_json(true_value, resources_map)
else:
return clean_json(false_value, resources_map)
if 'Fn::Join' in resource_json:
join_list = clean_json(resource_json['Fn::Join'][1], resources_map)
return resource_json['Fn::Join'][0].join([str(x) for x in join_list])
if "Fn::Join" in resource_json:
join_list = clean_json(resource_json["Fn::Join"][1], resources_map)
return resource_json["Fn::Join"][0].join([str(x) for x in join_list])
if 'Fn::Split' in resource_json:
to_split = clean_json(resource_json['Fn::Split'][1], resources_map)
return to_split.split(resource_json['Fn::Split'][0])
if "Fn::Split" in resource_json:
to_split = clean_json(resource_json["Fn::Split"][1], resources_map)
return to_split.split(resource_json["Fn::Split"][0])
if 'Fn::Select' in resource_json:
select_index = int(resource_json['Fn::Select'][0])
select_list = clean_json(resource_json['Fn::Select'][1], resources_map)
if "Fn::Select" in resource_json:
select_index = int(resource_json["Fn::Select"][0])
select_list = clean_json(resource_json["Fn::Select"][1], resources_map)
return select_list[select_index]
if 'Fn::Sub' in resource_json:
if isinstance(resource_json['Fn::Sub'], list):
if "Fn::Sub" in resource_json:
if isinstance(resource_json["Fn::Sub"], list):
warnings.warn(
"Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation")
"Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation"
)
else:
fn_sub_value = clean_json(resource_json['Fn::Sub'], resources_map)
to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value)
literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value)
fn_sub_value = clean_json(resource_json["Fn::Sub"], resources_map)
to_sub = re.findall(r'(?=\${)[^!^"]*?}', fn_sub_value)
literals = re.findall(r'(?=\${!)[^"]*?}', fn_sub_value)
for sub in to_sub:
if '.' in sub:
cleaned_ref = clean_json({'Fn::GetAtt': re.findall('(?<=\${)[^"]*?(?=})', sub)[0].split('.')}, resources_map)
if "." in sub:
cleaned_ref = clean_json(
{
"Fn::GetAtt": re.findall(r'(?<=\${)[^"]*?(?=})', sub)[
0
].split(".")
},
resources_map,
)
else:
cleaned_ref = clean_json({'Ref': re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, resources_map)
cleaned_ref = clean_json(
{"Ref": re.findall(r'(?<=\${)[^"]*?(?=})', sub)[0]},
resources_map,
)
fn_sub_value = fn_sub_value.replace(sub, cleaned_ref)
for literal in literals:
fn_sub_value = fn_sub_value.replace(literal, literal.replace('!', ''))
fn_sub_value = fn_sub_value.replace(
literal, literal.replace("!", "")
)
return fn_sub_value
pass
if 'Fn::ImportValue' in resource_json:
cleaned_val = clean_json(resource_json['Fn::ImportValue'], resources_map)
values = [x.value for x in resources_map.cross_stack_resources.values() if x.name == cleaned_val]
if "Fn::ImportValue" in resource_json:
cleaned_val = clean_json(resource_json["Fn::ImportValue"], resources_map)
values = [
x.value
for x in resources_map.cross_stack_resources.values()
if x.name == cleaned_val
]
if any(values):
return values[0]
else:
raise ExportNotFound(cleaned_val)
if 'Fn::GetAZs' in resource_json:
region = resource_json.get('Fn::GetAZs') or DEFAULT_REGION
if "Fn::GetAZs" in resource_json:
region = resource_json.get("Fn::GetAZs") or DEFAULT_REGION
result = []
# TODO: make this configurable, to reflect the real AWS AZs
for az in ('a', 'b', 'c', 'd'):
result.append('%s%s' % (region, az))
for az in ("a", "b", "c", "d"):
result.append("%s%s" % (region, az))
return result
cleaned_json = {}
@ -246,58 +276,69 @@ def resource_name_property_from_type(resource_type):
def generate_resource_name(resource_type, stack_name, logical_id):
if resource_type in ["AWS::ElasticLoadBalancingV2::TargetGroup",
"AWS::ElasticLoadBalancingV2::LoadBalancer"]:
if resource_type in [
"AWS::ElasticLoadBalancingV2::TargetGroup",
"AWS::ElasticLoadBalancingV2::LoadBalancer",
]:
# Target group names need to be less than 32 characters, so when cloudformation creates a name for you
# it makes sure to stay under that limit
name_prefix = '{0}-{1}'.format(stack_name, logical_id)
name_prefix = "{0}-{1}".format(stack_name, logical_id)
my_random_suffix = random_suffix()
truncated_name_prefix = name_prefix[0:32 - (len(my_random_suffix) + 1)]
truncated_name_prefix = name_prefix[0 : 32 - (len(my_random_suffix) + 1)]
# if the truncated name ends in a dash, we'll end up with a double dash in the final name, which is
# not allowed
if truncated_name_prefix.endswith('-'):
if truncated_name_prefix.endswith("-"):
truncated_name_prefix = truncated_name_prefix[:-1]
return '{0}-{1}'.format(truncated_name_prefix, my_random_suffix)
return "{0}-{1}".format(truncated_name_prefix, my_random_suffix)
else:
return '{0}-{1}-{2}'.format(stack_name, logical_id, random_suffix())
return "{0}-{1}-{2}".format(stack_name, logical_id, random_suffix())
def parse_resource(logical_id, resource_json, resources_map):
resource_type = resource_json['Type']
resource_type = resource_json["Type"]
resource_class = resource_class_from_type(resource_type)
if not resource_class:
warnings.warn(
"Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(resource_type))
"Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(
resource_type
)
)
return None
resource_json = clean_json(resource_json, resources_map)
resource_name_property = resource_name_property_from_type(resource_type)
if resource_name_property:
if 'Properties' not in resource_json:
resource_json['Properties'] = dict()
if resource_name_property not in resource_json['Properties']:
resource_json['Properties'][resource_name_property] = generate_resource_name(
resource_type, resources_map.get('AWS::StackName'), logical_id)
resource_name = resource_json['Properties'][resource_name_property]
if "Properties" not in resource_json:
resource_json["Properties"] = dict()
if resource_name_property not in resource_json["Properties"]:
resource_json["Properties"][
resource_name_property
] = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
resource_name = resource_json["Properties"][resource_name_property]
else:
resource_name = generate_resource_name(resource_type, resources_map.get('AWS::StackName'), logical_id)
resource_name = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
return resource_class, resource_json, resource_name
def parse_and_create_resource(logical_id, resource_json, resources_map, region_name):
condition = resource_json.get('Condition')
condition = resource_json.get("Condition")
if condition and not resources_map.lazy_condition_map[condition]:
# If this has a False condition, don't create the resource
return None
resource_type = resource_json['Type']
resource_type = resource_json["Type"]
resource_tuple = parse_resource(logical_id, resource_json, resources_map)
if not resource_tuple:
return None
resource_class, resource_json, resource_name = resource_tuple
resource = resource_class.create_from_cloudformation_json(
resource_name, resource_json, region_name)
resource_name, resource_json, region_name
)
resource.type = resource_type
resource.logical_resource_id = logical_id
return resource
@ -305,24 +346,27 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n
def parse_and_update_resource(logical_id, resource_json, resources_map, region_name):
resource_class, new_resource_json, new_resource_name = parse_resource(
logical_id, resource_json, resources_map)
logical_id, resource_json, resources_map
)
original_resource = resources_map[logical_id]
new_resource = resource_class.update_from_cloudformation_json(
original_resource=original_resource,
new_resource_name=new_resource_name,
cloudformation_json=new_resource_json,
region_name=region_name
region_name=region_name,
)
new_resource.type = resource_json['Type']
new_resource.type = resource_json["Type"]
new_resource.logical_resource_id = logical_id
return new_resource
def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name):
resource_class, resource_json, resource_name = parse_resource(
logical_id, resource_json, resources_map)
logical_id, resource_json, resources_map
)
resource_class.delete_from_cloudformation_json(
resource_name, resource_json, region_name)
resource_name, resource_json, region_name
)
def parse_condition(condition, resources_map, condition_map):
@ -334,8 +378,8 @@ def parse_condition(condition, resources_map, condition_map):
condition_values = []
for value in list(condition.values())[0]:
# Check if we are referencing another Condition
if 'Condition' in value:
condition_values.append(condition_map[value['Condition']])
if "Condition" in value:
condition_values.append(condition_map[value["Condition"]])
else:
condition_values.append(clean_json(value, resources_map))
@ -344,36 +388,49 @@ def parse_condition(condition, resources_map, condition_map):
elif condition_operator == "Fn::Not":
return not parse_condition(condition_values[0], resources_map, condition_map)
elif condition_operator == "Fn::And":
return all([
parse_condition(condition_value, resources_map, condition_map)
for condition_value
in condition_values])
return all(
[
parse_condition(condition_value, resources_map, condition_map)
for condition_value in condition_values
]
)
elif condition_operator == "Fn::Or":
return any([
parse_condition(condition_value, resources_map, condition_map)
for condition_value
in condition_values])
return any(
[
parse_condition(condition_value, resources_map, condition_map)
for condition_value in condition_values
]
)
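For illustration, a nested condition like the fragment below resolves through the branches above; the fragment is hypothetical, and it assumes the start of this function (outside this hunk) passes already-resolved boolean values through unchanged:

condition = {
    "Fn::And": [
        {"Condition": "IsProd"},       # bool pulled from condition_map
        {"Condition": "HasBackups"},
    ]
}
# parse_condition(condition, resources_map, condition_map) is True
# only when both named conditions evaluated to True.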
def parse_output(output_logical_id, output_json, resources_map):
output_json = clean_json(output_json, resources_map)
output = Output()
output.key = output_logical_id
output.value = clean_json(output_json['Value'], resources_map)
output.description = output_json.get('Description')
output.value = clean_json(output_json["Value"], resources_map)
output.description = output_json.get("Description")
return output
class ResourceMap(collections.Mapping):
class ResourceMap(collections_abc.Mapping):
"""
This is a lazy loading map for resources. This allows us to create resources
without needing to create a full dependency tree. Upon creation, each
resource is passed this lazy map that it can grab dependencies from.
"""
def __init__(self, stack_id, stack_name, parameters, tags, region_name, template, cross_stack_resources):
def __init__(
self,
stack_id,
stack_name,
parameters,
tags,
region_name,
template,
cross_stack_resources,
):
self._template = template
self._resource_json_map = template['Resources']
self._resource_json_map = template["Resources"]
self._region_name = region_name
self.input_parameters = parameters
self.tags = copy.deepcopy(tags)
@ -382,7 +439,7 @@ class ResourceMap(collections.Mapping):
# Create the default resources
self._parsed_resources = {
"AWS::AccountId": "123456789012",
"AWS::AccountId": ACCOUNT_ID,
"AWS::Region": self._region_name,
"AWS::StackId": stack_id,
"AWS::StackName": stack_name,
@ -401,7 +458,8 @@ class ResourceMap(collections.Mapping):
if not resource_json:
raise KeyError(resource_logical_id)
new_resource = parse_and_create_resource(
resource_logical_id, resource_json, self, self._region_name)
resource_logical_id, resource_json, self, self._region_name
)
if new_resource is not None:
self._parsed_resources[resource_logical_id] = new_resource
return new_resource
@ -417,13 +475,24 @@ class ResourceMap(collections.Mapping):
return self._resource_json_map.keys()
def load_mapping(self):
self._parsed_resources.update(self._template.get('Mappings', {}))
self._parsed_resources.update(self._template.get("Mappings", {}))
def transform_mapping(self):
for k, v in self._template.get("Mappings", {}).items():
if "Fn::Transform" in v:
name = v["Fn::Transform"]["Name"]
params = v["Fn::Transform"]["Parameters"]
if name == "AWS::Include":
location = params["Location"]
bucket_name, name = bucket_and_name_from_url(location)
key = s3_backend.get_key(bucket_name, name)
self._parsed_resources.update(json.loads(key.value))
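The new transform_mapping() resolves an AWS::Include transform by fetching the referenced document from the mocked S3 backend and merging its JSON into the parsed resources. A template fragment it would act on (bucket, key, and URL form are illustrative):

"Mappings": {
    "EnvironmentMap": {
        "Fn::Transform": {
            "Name": "AWS::Include",
            "Parameters": {"Location": "s3://example-bucket/mappings.json"},
        }
    }
}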
def load_parameters(self):
parameter_slots = self._template.get('Parameters', {})
parameter_slots = self._template.get("Parameters", {})
for parameter_name, parameter in parameter_slots.items():
# Set the default values.
self.resolved_parameters[parameter_name] = parameter.get('Default')
self.resolved_parameters[parameter_name] = parameter.get("Default")
# Set any input parameters that were passed
self.no_echo_parameter_keys = []
@ -431,11 +500,11 @@ class ResourceMap(collections.Mapping):
if key in self.resolved_parameters:
parameter_slot = parameter_slots[key]
value_type = parameter_slot.get('Type', 'String')
if value_type == 'CommaDelimitedList' or value_type.startswith("List"):
value = value.split(',')
value_type = parameter_slot.get("Type", "String")
if value_type == "CommaDelimitedList" or value_type.startswith("List"):
value = value.split(",")
if parameter_slot.get('NoEcho'):
if parameter_slot.get("NoEcho"):
self.no_echo_parameter_keys.append(key)
self.resolved_parameters[key] = value
@ -449,29 +518,39 @@ class ResourceMap(collections.Mapping):
self._parsed_resources.update(self.resolved_parameters)
def load_conditions(self):
conditions = self._template.get('Conditions', {})
conditions = self._template.get("Conditions", {})
self.lazy_condition_map = LazyDict()
for condition_name, condition in conditions.items():
self.lazy_condition_map[condition_name] = functools.partial(parse_condition,
condition, self._parsed_resources, self.lazy_condition_map)
self.lazy_condition_map[condition_name] = functools.partial(
parse_condition,
condition,
self._parsed_resources,
self.lazy_condition_map,
)
for condition_name in self.lazy_condition_map:
self.lazy_condition_map[condition_name]
def create(self):
self.load_mapping()
self.transform_mapping()
self.load_parameters()
self.load_conditions()
# Since this is a lazy map, to create every object we just need to
# iterate through self.
self.tags.update({'aws:cloudformation:stack-name': self.get('AWS::StackName'),
'aws:cloudformation:stack-id': self.get('AWS::StackId')})
self.tags.update(
{
"aws:cloudformation:stack-name": self.get("AWS::StackName"),
"aws:cloudformation:stack-id": self.get("AWS::StackId"),
}
)
for resource in self.resources:
if isinstance(self[resource], ec2_models.TaggedEC2Resource):
self.tags['aws:cloudformation:logical-id'] = resource
self.tags["aws:cloudformation:logical-id"] = resource
ec2_models.ec2_backends[self._region_name].create_tags(
[self[resource].physical_resource_id], self.tags)
[self[resource].physical_resource_id], self.tags
)
def diff(self, template, parameters=None):
if parameters:
@ -481,36 +560,35 @@ class ResourceMap(collections.Mapping):
self.load_conditions()
old_template = self._resource_json_map
new_template = template['Resources']
new_template = template["Resources"]
resource_names_by_action = {
'Add': set(new_template) - set(old_template),
'Modify': set(name for name in new_template if name in old_template and new_template[
name] != old_template[name]),
'Remove': set(old_template) - set(new_template)
}
resources_by_action = {
'Add': {},
'Modify': {},
'Remove': {},
"Add": set(new_template) - set(old_template),
"Modify": set(
name
for name in new_template
if name in old_template and new_template[name] != old_template[name]
),
"Remove": set(old_template) - set(new_template),
}
resources_by_action = {"Add": {}, "Modify": {}, "Remove": {}}
for resource_name in resource_names_by_action['Add']:
resources_by_action['Add'][resource_name] = {
'LogicalResourceId': resource_name,
'ResourceType': new_template[resource_name]['Type']
for resource_name in resource_names_by_action["Add"]:
resources_by_action["Add"][resource_name] = {
"LogicalResourceId": resource_name,
"ResourceType": new_template[resource_name]["Type"],
}
for resource_name in resource_names_by_action['Modify']:
resources_by_action['Modify'][resource_name] = {
'LogicalResourceId': resource_name,
'ResourceType': new_template[resource_name]['Type']
for resource_name in resource_names_by_action["Modify"]:
resources_by_action["Modify"][resource_name] = {
"LogicalResourceId": resource_name,
"ResourceType": new_template[resource_name]["Type"],
}
for resource_name in resource_names_by_action['Remove']:
resources_by_action['Remove'][resource_name] = {
'LogicalResourceId': resource_name,
'ResourceType': old_template[resource_name]['Type']
for resource_name in resource_names_by_action["Remove"]:
resources_by_action["Remove"][resource_name] = {
"LogicalResourceId": resource_name,
"ResourceType": old_template[resource_name]["Type"],
}
return resources_by_action
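The returned structure groups logical ids by change action. A hypothetical result for a change set that adds a queue and modifies a bucket:

{
    "Add": {
        "NewQueue": {
            "LogicalResourceId": "NewQueue",
            "ResourceType": "AWS::SQS::Queue",
        }
    },
    "Modify": {
        "DataBucket": {
            "LogicalResourceId": "DataBucket",
            "ResourceType": "AWS::S3::Bucket",
        }
    },
    "Remove": {},
}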
@ -519,35 +597,38 @@ class ResourceMap(collections.Mapping):
resources_by_action = self.diff(template, parameters)
old_template = self._resource_json_map
new_template = template['Resources']
new_template = template["Resources"]
self._resource_json_map = new_template
for resource_name, resource in resources_by_action['Add'].items():
for resource_name, resource in resources_by_action["Add"].items():
resource_json = new_template[resource_name]
new_resource = parse_and_create_resource(
resource_name, resource_json, self, self._region_name)
resource_name, resource_json, self, self._region_name
)
self._parsed_resources[resource_name] = new_resource
for resource_name, resource in resources_by_action['Remove'].items():
for resource_name, resource in resources_by_action["Remove"].items():
resource_json = old_template[resource_name]
parse_and_delete_resource(
resource_name, resource_json, self, self._region_name)
resource_name, resource_json, self, self._region_name
)
self._parsed_resources.pop(resource_name)
tries = 1
while resources_by_action['Modify'] and tries < 5:
for resource_name, resource in resources_by_action['Modify'].copy().items():
while resources_by_action["Modify"] and tries < 5:
for resource_name, resource in resources_by_action["Modify"].copy().items():
resource_json = new_template[resource_name]
try:
changed_resource = parse_and_update_resource(
resource_name, resource_json, self, self._region_name)
resource_name, resource_json, self, self._region_name
)
except Exception as e:
# skip over dependency violations, and try again in a
# second pass
last_exception = e
else:
self._parsed_resources[resource_name] = changed_resource
del resources_by_action['Modify'][resource_name]
del resources_by_action["Modify"][resource_name]
tries += 1
if tries == 5:
raise last_exception
@ -559,7 +640,7 @@ class ResourceMap(collections.Mapping):
for resource in remaining_resources.copy():
parsed_resource = self._parsed_resources.get(resource)
try:
if parsed_resource and hasattr(parsed_resource, 'delete'):
if parsed_resource and hasattr(parsed_resource, "delete"):
parsed_resource.delete(self._region_name)
except Exception as e:
# skip over dependency violations, and try again in a
@ -572,12 +653,11 @@ class ResourceMap(collections.Mapping):
raise last_exception
class OutputMap(collections.Mapping):
class OutputMap(collections_abc.Mapping):
def __init__(self, resources, template, stack_id):
self._template = template
self._stack_id = stack_id
self._output_json_map = template.get('Outputs')
self._output_json_map = template.get("Outputs")
# Create the default resources
self._resource_map = resources
@ -591,7 +671,8 @@ class OutputMap(collections.Mapping):
else:
output_json = self._output_json_map.get(output_logical_id)
new_output = parse_output(
output_logical_id, output_json, self._resource_map)
output_logical_id, output_json, self._resource_map
)
self._parsed_outputs[output_logical_id] = new_output
return new_output
@ -610,9 +691,11 @@ class OutputMap(collections.Mapping):
exports = []
if self.outputs:
for key, value in self._output_json_map.items():
if value.get('Export'):
cleaned_name = clean_json(value['Export'].get('Name'), self._resource_map)
cleaned_value = clean_json(value.get('Value'), self._resource_map)
if value.get("Export"):
cleaned_name = clean_json(
value["Export"].get("Name"), self._resource_map
)
cleaned_value = clean_json(value.get("Value"), self._resource_map)
exports.append(Export(self._stack_id, cleaned_name, cleaned_value))
return exports
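An Outputs block that would yield one Export from the loop above (names are made up; the Value resolves through clean_json before being stored):

"Outputs": {
    "BucketName": {
        "Value": {"Ref": "DataBucket"},
        "Export": {"Name": "shared-bucket-name"},
    }
}
# -> exports contains Export(self._stack_id, "shared-bucket-name", <resolved bucket id>)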
@ -622,7 +705,6 @@ class OutputMap(collections.Mapping):
class Export(object):
def __init__(self, exporting_stack_id, name, value):
self._exporting_stack_id = exporting_stack_id
self._name = name

View File

@ -7,12 +7,12 @@ from six.moves.urllib.parse import urlparse
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from moto.s3 import s3_backend
from moto.core import ACCOUNT_ID
from .models import cloudformation_backends
from .exceptions import ValidationError
class CloudFormationResponse(BaseResponse):
@property
def cloudformation_backend(self):
return cloudformation_backends[self.region]
@ -20,17 +20,18 @@ class CloudFormationResponse(BaseResponse):
def _get_stack_from_s3_url(self, template_url):
template_url_parts = urlparse(template_url)
if "localhost" in template_url:
bucket_name, key_name = template_url_parts.path.lstrip(
"/").split("/", 1)
bucket_name, key_name = template_url_parts.path.lstrip("/").split("/", 1)
else:
if template_url_parts.netloc.endswith('amazonaws.com') \
and template_url_parts.netloc.startswith('s3'):
if template_url_parts.netloc.endswith(
"amazonaws.com"
) and template_url_parts.netloc.startswith("s3"):
# Handle when S3 url uses amazon url with bucket in path
# Also handles getting region as technically s3 is region'd
# region = template_url.netloc.split('.')[1]
bucket_name, key_name = template_url_parts.path.lstrip(
"/").split("/", 1)
bucket_name, key_name = template_url_parts.path.lstrip("/").split(
"/", 1
)
else:
bucket_name = template_url_parts.netloc.split(".")[0]
key_name = template_url_parts.path.lstrip("/")
@ -39,24 +40,26 @@ class CloudFormationResponse(BaseResponse):
return key.value.decode("utf-8")
def create_stack(self):
stack_name = self._get_param('StackName')
stack_body = self._get_param('TemplateBody')
template_url = self._get_param('TemplateURL')
role_arn = self._get_param('RoleARN')
stack_name = self._get_param("StackName")
stack_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
role_arn = self._get_param("RoleARN")
parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# Hack dict-comprehension
parameters = dict([
(parameter['parameter_key'], parameter['parameter_value'])
for parameter
in parameters_list
])
parameters = dict(
[
(parameter["parameter_key"], parameter["parameter_value"])
for parameter in parameters_list
]
)
if template_url:
stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param(
'NotificationARNs.member')
stack_notification_arns = self._get_multi_param("NotificationARNs.member")
stack = self.cloudformation_backend.create_stack(
name=stack_name,
@ -68,34 +71,37 @@ class CloudFormationResponse(BaseResponse):
role_arn=role_arn,
)
if self.request_json:
return json.dumps({
'CreateStackResponse': {
'CreateStackResult': {
'StackId': stack.stack_id,
return json.dumps(
{
"CreateStackResponse": {
"CreateStackResult": {"StackId": stack.stack_id}
}
}
})
)
else:
template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE)
return template.render(stack=stack)
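
A minimal sketch of the request this handler parses, assuming mock_cloudformation; boto3's Parameters and Tags lists arrive as Parameters.member.N and Tags.member.N entries in the query string (names and values hypothetical):

import json

import boto3
from moto import mock_cloudformation

@mock_cloudformation
def create_demo_stack():
    cf = boto3.client("cloudformation", region_name="us-east-1")
    body = json.dumps(
        {"Resources": {"Queue": {"Type": "AWS::SQS::Queue", "Properties": {"QueueName": "demo-queue"}}}}
    )
    resp = cf.create_stack(
        StackName="demo",
        TemplateBody=body,
        Tags=[{"Key": "team", "Value": "platform"}],
    )
    print(resp["StackId"])  # arn:aws:cloudformation:...:stack/demo/<uuid>

create_demo_stack()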
@amzn_request_id
def create_change_set(self):
stack_name = self._get_param('StackName')
change_set_name = self._get_param('ChangeSetName')
stack_body = self._get_param('TemplateBody')
template_url = self._get_param('TemplateURL')
role_arn = self._get_param('RoleARN')
update_or_create = self._get_param('ChangeSetType', 'CREATE')
stack_name = self._get_param("StackName")
change_set_name = self._get_param("ChangeSetName")
stack_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
role_arn = self._get_param("RoleARN")
update_or_create = self._get_param("ChangeSetType", "CREATE")
parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
parameters = {param['parameter_key']: param['parameter_value']
for param in parameters_list}
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
parameters = {
param["parameter_key"]: param["parameter_value"]
for param in parameters_list
}
if template_url:
stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param(
'NotificationARNs.member')
stack_notification_arns = self._get_multi_param("NotificationARNs.member")
change_set_id, stack_id = self.cloudformation_backend.create_change_set(
stack_name=stack_name,
change_set_name=change_set_name,
@ -108,66 +114,64 @@ class CloudFormationResponse(BaseResponse):
change_set_type=update_or_create,
)
if self.request_json:
return json.dumps({
'CreateChangeSetResponse': {
'CreateChangeSetResult': {
'Id': change_set_id,
'StackId': stack_id,
return json.dumps(
{
"CreateChangeSetResponse": {
"CreateChangeSetResult": {
"Id": change_set_id,
"StackId": stack_id,
}
}
}
})
)
else:
template = self.response_template(CREATE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render(stack_id=stack_id, change_set_id=change_set_id)
def delete_change_set(self):
stack_name = self._get_param('StackName')
change_set_name = self._get_param('ChangeSetName')
stack_name = self._get_param("StackName")
change_set_name = self._get_param("ChangeSetName")
self.cloudformation_backend.delete_change_set(change_set_name=change_set_name, stack_name=stack_name)
self.cloudformation_backend.delete_change_set(
change_set_name=change_set_name, stack_name=stack_name
)
if self.request_json:
return json.dumps({
'DeleteChangeSetResponse': {
'DeleteChangeSetResult': {},
}
})
return json.dumps(
{"DeleteChangeSetResponse": {"DeleteChangeSetResult": {}}}
)
else:
template = self.response_template(DELETE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render()
def describe_change_set(self):
stack_name = self._get_param('StackName')
change_set_name = self._get_param('ChangeSetName')
stack_name = self._get_param("StackName")
change_set_name = self._get_param("ChangeSetName")
change_set = self.cloudformation_backend.describe_change_set(
change_set_name=change_set_name,
stack_name=stack_name,
change_set_name=change_set_name, stack_name=stack_name
)
template = self.response_template(DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render(change_set=change_set)
@amzn_request_id
def execute_change_set(self):
stack_name = self._get_param('StackName')
change_set_name = self._get_param('ChangeSetName')
stack_name = self._get_param("StackName")
change_set_name = self._get_param("ChangeSetName")
self.cloudformation_backend.execute_change_set(
stack_name=stack_name,
change_set_name=change_set_name,
stack_name=stack_name, change_set_name=change_set_name
)
if self.request_json:
return json.dumps({
'ExecuteChangeSetResponse': {
'ExecuteChangeSetResult': {},
}
})
return json.dumps(
{"ExecuteChangeSetResponse": {"ExecuteChangeSetResult": {}}}
)
else:
template = self.response_template(EXECUTE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render()
def describe_stacks(self):
stack_name_or_id = None
if self._get_param('StackName'):
stack_name_or_id = self.querystring.get('StackName')[0]
token = self._get_param('NextToken')
if self._get_param("StackName"):
stack_name_or_id = self.querystring.get("StackName")[0]
token = self._get_param("NextToken")
stacks = self.cloudformation_backend.describe_stacks(stack_name_or_id)
stack_ids = [stack.stack_id for stack in stacks]
if token:
@ -175,7 +179,7 @@ class CloudFormationResponse(BaseResponse):
else:
start = 0
max_results = 50  # using this to make testing of paginated stacks more convenient than the default 1 MB
stacks_resp = stacks[start:start + max_results]
stacks_resp = stacks[start : start + max_results]
next_token = None
if len(stacks) > (start + max_results):
next_token = stacks_resp[-1].stack_id
@ -183,9 +187,9 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacks=stacks_resp, next_token=next_token)
def describe_stack_resource(self):
stack_name = self._get_param('StackName')
stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name)
logical_resource_id = self._get_param('LogicalResourceId')
logical_resource_id = self._get_param("LogicalResourceId")
for stack_resource in stack.stack_resources:
if stack_resource.logical_resource_id == logical_resource_id:
@ -194,19 +198,18 @@ class CloudFormationResponse(BaseResponse):
else:
raise ValidationError(logical_resource_id)
template = self.response_template(
DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
template = self.response_template(DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
return template.render(stack=stack, resource=resource)
def describe_stack_resources(self):
stack_name = self._get_param('StackName')
stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name)
template = self.response_template(DESCRIBE_STACK_RESOURCES_RESPONSE)
return template.render(stack=stack)
def describe_stack_events(self):
stack_name = self._get_param('StackName')
stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name)
template = self.response_template(DESCRIBE_STACK_EVENTS_RESPONSE)
@ -223,68 +226,85 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacks=stacks)
def list_stack_resources(self):
stack_name_or_id = self._get_param('StackName')
resources = self.cloudformation_backend.list_stack_resources(
stack_name_or_id)
stack_name_or_id = self._get_param("StackName")
resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id)
if resources is None:
raise ValidationError(stack_name_or_id)
template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE)
return template.render(resources=resources)
def get_template(self):
name_or_stack_id = self.querystring.get('StackName')[0]
name_or_stack_id = self.querystring.get("StackName")[0]
stack = self.cloudformation_backend.get_stack(name_or_stack_id)
if self.request_json:
return json.dumps({
"GetTemplateResponse": {
"GetTemplateResult": {
"TemplateBody": stack.template,
"ResponseMetadata": {
"RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE"
return json.dumps(
{
"GetTemplateResponse": {
"GetTemplateResult": {
"TemplateBody": stack.template,
"ResponseMetadata": {
"RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE"
},
}
}
}
})
)
else:
template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE)
return template.render(stack=stack)
def update_stack(self):
stack_name = self._get_param('StackName')
role_arn = self._get_param('RoleARN')
template_url = self._get_param('TemplateURL')
stack_body = self._get_param('TemplateBody')
stack_name = self._get_param("StackName")
role_arn = self._get_param("RoleARN")
template_url = self._get_param("TemplateURL")
stack_body = self._get_param("TemplateBody")
stack = self.cloudformation_backend.get_stack(stack_name)
if self._get_param('UsePreviousTemplate') == "true":
if self._get_param("UsePreviousTemplate") == "true":
stack_body = stack.template
elif not stack_body and template_url:
stack_body = self._get_stack_from_s3_url(template_url)
incoming_params = self._get_list_prefix("Parameters.member")
parameters = dict([
(parameter['parameter_key'], parameter['parameter_value'])
for parameter
in incoming_params if 'parameter_value' in parameter
])
previous = dict([
(parameter['parameter_key'], stack.parameters[parameter['parameter_key']])
for parameter
in incoming_params if 'use_previous_value' in parameter
])
parameters = dict(
[
(parameter["parameter_key"], parameter["parameter_value"])
for parameter in incoming_params
if "parameter_value" in parameter
]
)
previous = dict(
[
(
parameter["parameter_key"],
stack.parameters[parameter["parameter_key"]],
)
for parameter in incoming_params
if "use_previous_value" in parameter
]
)
parameters.update(previous)
# boto3 is supposed to let you clear the tags by passing an empty value, but the request body doesn't
# end up containing anything we can use to differentiate between passing an empty value versus not
# passing anything. So until that changes, moto won't be able to clear tags, only update them.
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# so that if we don't pass the parameter, we don't clear all the tags accidentally
if not tags:
tags = None
stack = self.cloudformation_backend.get_stack(stack_name)
if stack.status == 'ROLLBACK_COMPLETE':
if stack.status == "ROLLBACK_COMPLETE":
raise ValidationError(
stack.stack_id, message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(stack.stack_id))
stack.stack_id,
message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(
stack.stack_id
),
)
stack = self.cloudformation_backend.update_stack(
name=stack_name,
@ -295,11 +315,7 @@ class CloudFormationResponse(BaseResponse):
)
if self.request_json:
stack_body = {
'UpdateStackResponse': {
'UpdateStackResult': {
'StackId': stack.name,
}
}
"UpdateStackResponse": {"UpdateStackResult": {"StackId": stack.name}}
}
return json.dumps(stack_body)
else:
@ -307,56 +323,57 @@ class CloudFormationResponse(BaseResponse):
return template.render(stack=stack)
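
To illustrate the tag caveat noted above, a hedged sketch: omitting Tags on update leaves the previous tags in place, because moto cannot tell an absent parameter from an empty one (names and values hypothetical):

import json

import boto3
from moto import mock_cloudformation

@mock_cloudformation
def tags_survive_update():
    cf = boto3.client("cloudformation", region_name="us-east-1")
    body = json.dumps(
        {"Resources": {"Queue": {"Type": "AWS::SQS::Queue", "Properties": {"QueueName": "demo-queue"}}}}
    )
    cf.create_stack(StackName="demo", TemplateBody=body,
                    Tags=[{"Key": "team", "Value": "platform"}])
    cf.update_stack(StackName="demo", TemplateBody=body)  # no Tags argument
    stack = cf.describe_stacks(StackName="demo")["Stacks"][0]
    print(stack["Tags"])  # the original team/platform tag should still be there

tags_survive_update()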
def delete_stack(self):
name_or_stack_id = self.querystring.get('StackName')[0]
name_or_stack_id = self.querystring.get("StackName")[0]
self.cloudformation_backend.delete_stack(name_or_stack_id)
if self.request_json:
return json.dumps({
'DeleteStackResponse': {
'DeleteStackResult': {},
}
})
return json.dumps({"DeleteStackResponse": {"DeleteStackResult": {}}})
else:
template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE)
return template.render()
def list_exports(self):
token = self._get_param('NextToken')
token = self._get_param("NextToken")
exports, next_token = self.cloudformation_backend.list_exports(token=token)
template = self.response_template(LIST_EXPORTS_RESPONSE)
return template.render(exports=exports, next_token=next_token)
def validate_template(self):
cfn_lint = self.cloudformation_backend.validate_template(self._get_param('TemplateBody'))
cfn_lint = self.cloudformation_backend.validate_template(
self._get_param("TemplateBody")
)
if cfn_lint:
raise ValidationError(cfn_lint[0].message)
description = ""
try:
description = json.loads(self._get_param('TemplateBody'))['Description']
description = json.loads(self._get_param("TemplateBody"))["Description"]
except (ValueError, KeyError):
pass
try:
description = yaml.load(self._get_param('TemplateBody'))['Description']
description = yaml.load(self._get_param("TemplateBody"))["Description"]
except (yaml.parser.ParserError, KeyError):
pass
template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE)
return template.render(description=description)
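
A sketch of the description extraction above: the body is tried as JSON first, then as YAML, so either encoding yields the same Description (assuming mock_cloudformation; the template content is hypothetical):

import json

import boto3
from moto import mock_cloudformation

@mock_cloudformation
def validate_demo():
    cf = boto3.client("cloudformation", region_name="us-east-1")
    body = json.dumps({
        "AWSTemplateFormatVersion": "2010-09-09",
        "Description": "demo template",
        "Resources": {"Queue": {"Type": "AWS::SQS::Queue", "Properties": {"QueueName": "demo-queue"}}},
    })
    print(cf.validate_template(TemplateBody=body)["Description"])  # demo template

validate_demo()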
def create_stack_set(self):
stackset_name = self._get_param('StackSetName')
stack_body = self._get_param('TemplateBody')
template_url = self._get_param('TemplateURL')
stackset_name = self._get_param("StackSetName")
stack_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
# role_arn = self._get_param('RoleARN')
parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# Copy-Pasta - Hack dict-comprehension
parameters = dict([
(parameter['parameter_key'], parameter['parameter_value'])
for parameter
in parameters_list
])
parameters = dict(
[
(parameter["parameter_key"], parameter["parameter_value"])
for parameter in parameters_list
]
)
if template_url:
stack_body = self._get_stack_from_s3_url(template_url)
@ -368,59 +385,67 @@ class CloudFormationResponse(BaseResponse):
# role_arn=role_arn,
)
if self.request_json:
return json.dumps({
'CreateStackSetResponse': {
'CreateStackSetResult': {
'StackSetId': stackset.stackset_id,
return json.dumps(
{
"CreateStackSetResponse": {
"CreateStackSetResult": {"StackSetId": stackset.stackset_id}
}
}
})
)
else:
template = self.response_template(CREATE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(stackset=stackset)
def create_stack_instances(self):
stackset_name = self._get_param('StackSetName')
accounts = self._get_multi_param('Accounts.member')
regions = self._get_multi_param('Regions.member')
parameters = self._get_multi_param('ParameterOverrides.member')
self.cloudformation_backend.create_stack_instances(stackset_name, accounts, regions, parameters)
stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param("Regions.member")
parameters = self._get_multi_param("ParameterOverrides.member")
self.cloudformation_backend.create_stack_instances(
stackset_name, accounts, regions, parameters
)
template = self.response_template(CREATE_STACK_INSTANCES_TEMPLATE)
return template.render()
def delete_stack_set(self):
stackset_name = self._get_param('StackSetName')
stackset_name = self._get_param("StackSetName")
self.cloudformation_backend.delete_stack_set(stackset_name)
template = self.response_template(DELETE_STACK_SET_RESPONSE_TEMPLATE)
return template.render()
def delete_stack_instances(self):
stackset_name = self._get_param('StackSetName')
accounts = self._get_multi_param('Accounts.member')
regions = self._get_multi_param('Regions.member')
operation = self.cloudformation_backend.delete_stack_instances(stackset_name, accounts, regions)
stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param("Regions.member")
operation = self.cloudformation_backend.delete_stack_instances(
stackset_name, accounts, regions
)
template = self.response_template(DELETE_STACK_INSTANCES_TEMPLATE)
return template.render(operation=operation)
def describe_stack_set(self):
stackset_name = self._get_param('StackSetName')
stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
if not stackset.admin_role:
stackset.admin_role = 'arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole'
stackset.admin_role = "arn:aws:iam::{AccountId}:role/AWSCloudFormationStackSetAdministrationRole".format(
AccountId=ACCOUNT_ID
)
if not stackset.execution_role:
stackset.execution_role = 'AWSCloudFormationStackSetExecutionRole'
stackset.execution_role = "AWSCloudFormationStackSetExecutionRole"
template = self.response_template(DESCRIBE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(stackset=stackset)
def describe_stack_instance(self):
stackset_name = self._get_param('StackSetName')
account = self._get_param('StackInstanceAccount')
region = self._get_param('StackInstanceRegion')
stackset_name = self._get_param("StackSetName")
account = self._get_param("StackInstanceAccount")
region = self._get_param("StackInstanceRegion")
instance = self.cloudformation_backend.get_stack_set(stackset_name).instances.get_instance(account, region)
instance = self.cloudformation_backend.get_stack_set(
stackset_name
).instances.get_instance(account, region)
template = self.response_template(DESCRIBE_STACK_INSTANCE_TEMPLATE)
rendered = template.render(instance=instance)
return rendered
@ -431,61 +456,66 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacksets=stacksets)
def list_stack_instances(self):
stackset_name = self._get_param('StackSetName')
stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
template = self.response_template(LIST_STACK_INSTANCES_TEMPLATE)
return template.render(stackset=stackset)
def list_stack_set_operations(self):
stackset_name = self._get_param('StackSetName')
stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
template = self.response_template(LIST_STACK_SET_OPERATIONS_RESPONSE_TEMPLATE)
return template.render(stackset=stackset)
def stop_stack_set_operation(self):
stackset_name = self._get_param('StackSetName')
operation_id = self._get_param('OperationId')
stackset_name = self._get_param("StackSetName")
operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
stackset.update_operation(operation_id, 'STOPPED')
stackset.update_operation(operation_id, "STOPPED")
template = self.response_template(STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE)
return template.render()
def describe_stack_set_operation(self):
stackset_name = self._get_param('StackSetName')
operation_id = self._get_param('OperationId')
stackset_name = self._get_param("StackSetName")
operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
operation = stackset.get_operation(operation_id)
template = self.response_template(DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE)
return template.render(stackset=stackset, operation=operation)
def list_stack_set_operation_results(self):
stackset_name = self._get_param('StackSetName')
operation_id = self._get_param('OperationId')
stackset_name = self._get_param("StackSetName")
operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
operation = stackset.get_operation(operation_id)
template = self.response_template(LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE)
template = self.response_template(
LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE
)
return template.render(operation=operation)
def update_stack_set(self):
stackset_name = self._get_param('StackSetName')
operation_id = self._get_param('OperationId')
description = self._get_param('Description')
execution_role = self._get_param('ExecutionRoleName')
admin_role = self._get_param('AdministrationRoleARN')
accounts = self._get_multi_param('Accounts.member')
regions = self._get_multi_param('Regions.member')
template_body = self._get_param('TemplateBody')
template_url = self._get_param('TemplateURL')
stackset_name = self._get_param("StackSetName")
operation_id = self._get_param("OperationId")
description = self._get_param("Description")
execution_role = self._get_param("ExecutionRoleName")
admin_role = self._get_param("AdministrationRoleARN")
accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param("Regions.member")
template_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
if template_url:
template_body = self._get_stack_from_s3_url(template_url)
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
parameters_list = self._get_list_prefix("Parameters.member")
parameters = dict([
(parameter['parameter_key'], parameter['parameter_value'])
for parameter
in parameters_list
])
parameters = dict(
[
(parameter["parameter_key"], parameter["parameter_value"])
for parameter in parameters_list
]
)
operation = self.cloudformation_backend.update_stack_set(
stackset_name=stackset_name,
template=template_body,
@ -496,18 +526,20 @@ class CloudFormationResponse(BaseResponse):
execution_role=execution_role,
accounts=accounts,
regions=regions,
operation_id=operation_id
operation_id=operation_id,
)
template = self.response_template(UPDATE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(operation=operation)
def update_stack_instances(self):
stackset_name = self._get_param('StackSetName')
accounts = self._get_multi_param('Accounts.member')
regions = self._get_multi_param('Regions.member')
parameters = self._get_multi_param('ParameterOverrides.member')
operation = self.cloudformation_backend.get_stack_set(stackset_name).update_instances(accounts, regions, parameters)
stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param("Regions.member")
parameters = self._get_multi_param("ParameterOverrides.member")
operation = self.cloudformation_backend.get_stack_set(
stackset_name
).update_instances(accounts, regions, parameters)
template = self.response_template(UPDATE_STACK_INSTANCES_RESPONSE_TEMPLATE)
return template.render(operation=operation)
@ -630,7 +662,7 @@ DESCRIBE_STACKS_TEMPLATE = """<DescribeStacksResponse>
<member>
<StackName>{{ stack.name }}</StackName>
<StackId>{{ stack.stack_id }}</StackId>
<CreationTime>2010-07-27T22:28:28Z</CreationTime>
<CreationTime>{{ stack.creation_time_iso_8601 }}</CreationTime>
<StackStatus>{{ stack.status }}</StackStatus>
{% if stack.notification_arns %}
<NotificationARNs>
@ -771,7 +803,7 @@ LIST_STACKS_RESPONSE = """<ListStacksResponse>
<StackId>{{ stack.stack_id }}</StackId>
<StackStatus>{{ stack.status }}</StackStatus>
<StackName>{{ stack.name }}</StackName>
<CreationTime>2011-05-23T15:47:44Z</CreationTime>
<CreationTime>{{ stack.creation_time_iso_8601 }}</CreationTime>
<TemplateDescription>{{ stack.description }}</TemplateDescription>
</member>
{% endfor %}
@ -1025,11 +1057,14 @@ STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE = """<StopStackSetOperationResponse x
</ResponseMetadata> </StopStackSetOperationResponse>
"""
DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = """<DescribeStackSetOperationResponse xmlns="http://internal.amazon.com/coral/com.amazonaws.maestro.service.v20160713/">
DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = (
"""<DescribeStackSetOperationResponse xmlns="http://internal.amazon.com/coral/com.amazonaws.maestro.service.v20160713/">
<DescribeStackSetOperationResult>
<StackSetOperation>
<ExecutionRoleName>{{ stackset.execution_role }}</ExecutionRoleName>
<AdministrationRoleARN>arn:aws:iam::123456789012:role/{{ stackset.admin_role }}</AdministrationRoleARN>
<AdministrationRoleARN>arn:aws:iam::"""
+ ACCOUNT_ID
+ """:role/{{ stackset.admin_role }}</AdministrationRoleARN>
<StackSetId>{{ stackset.id }}</StackSetId>
<CreationTimestamp>{{ operation.CreationTimestamp }}</CreationTimestamp>
<OperationId>{{ operation.OperationId }}</OperationId>
@ -1046,15 +1081,19 @@ DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = """<DescribeStackSetOperationRes
</ResponseMetadata>
</DescribeStackSetOperationResponse>
"""
)
LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = """<ListStackSetOperationResultsResponse xmlns="http://internal.amazon.com/coral/com.amazonaws.maestro.service.v20160713/">
LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = (
"""<ListStackSetOperationResultsResponse xmlns="http://internal.amazon.com/coral/com.amazonaws.maestro.service.v20160713/">
<ListStackSetOperationResultsResult>
<Summaries>
{% for instance in operation.Instances %}
{% for account, region in instance.items() %}
<member>
<AccountGateResult>
<StatusReason>Function not found: arn:aws:lambda:us-west-2:123456789012:function:AWSCloudFormationStackSetAccountGate</StatusReason>
<StatusReason>Function not found: arn:aws:lambda:us-west-2:"""
+ ACCOUNT_ID
+ """:function:AWSCloudFormationStackSetAccountGate</StatusReason>
<Status>SKIPPED</Status>
</AccountGateResult>
<Region>{{ region }}</Region>
@ -1070,3 +1109,4 @@ LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = """<ListStackSetOperationRe
</ResponseMetadata>
</ListStackSetOperationResultsResponse>
"""
)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import CloudFormationResponse
url_bases = [
"https?://cloudformation.(.+).amazonaws.com",
]
url_bases = ["https?://cloudformation.(.+).amazonaws.com"]
url_paths = {
'{0}/$': CloudFormationResponse.dispatch,
}
url_paths = {"{0}/$": CloudFormationResponse.dispatch}

View File

@ -7,48 +7,56 @@ import os
import string
from cfnlint import decode, core
from moto.core import ACCOUNT_ID
def generate_stack_id(stack_name, region="us-east-1", account="123456789"):
random_id = uuid.uuid4()
return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(region, account, stack_name, random_id)
return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(
region, account, stack_name, random_id
)
def generate_changeset_id(changeset_name, region_name):
random_id = uuid.uuid4()
return 'arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}'.format(region_name, changeset_name, random_id)
return "arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}".format(
region_name, changeset_name, random_id
)
def generate_stackset_id(stackset_name):
random_id = uuid.uuid4()
return '{}:{}'.format(stackset_name, random_id)
return "{}:{}".format(stackset_name, random_id)
def generate_stackset_arn(stackset_id, region_name):
return 'arn:aws:cloudformation:{}:123456789012:stackset/{}'.format(region_name, stackset_id)
return "arn:aws:cloudformation:{}:{}:stackset/{}".format(
region_name, ACCOUNT_ID, stackset_id
)
def random_suffix():
size = 12
chars = list(range(10)) + list(string.ascii_uppercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size))
return "".join(six.text_type(random.choice(chars)) for x in range(size))
def yaml_tag_constructor(loader, tag, node):
"""convert shorthand intrinsic function to full name
"""
def _f(loader, tag, node):
if tag == '!GetAtt':
return node.value.split('.')
if tag == "!GetAtt":
return node.value.split(".")
elif type(node) == yaml.SequenceNode:
return loader.construct_sequence(node)
else:
return node.value
if tag == '!Ref':
key = 'Ref'
if tag == "!Ref":
key = "Ref"
else:
key = 'Fn::{}'.format(tag[1:])
key = "Fn::{}".format(tag[1:])
return {key: _f(loader, tag, node)}
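
The constructor's effect, assuming it is registered with an empty tag prefix (so the full tag, '!' included, reaches the checks above, which is how moto wires it up elsewhere in this module):

import yaml

yaml.add_multi_constructor("", yaml_tag_constructor, Loader=yaml.Loader)

print(yaml.load("Value: !GetAtt MyBucket.Arn", Loader=yaml.Loader))
# {'Value': {'Fn::GetAtt': ['MyBucket', 'Arn']}}
print(yaml.load("Target: !Ref MyBucket", Loader=yaml.Loader))
# {'Target': {'Ref': 'MyBucket'}}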
@ -71,13 +79,9 @@ def validate_template_cfn_lint(template):
rules = core.get_rules([], [], [])
# Use us-east-1 region (spec file) for validation
regions = ['us-east-1']
regions = ["us-east-1"]
# Process all the rules and gather the errors
matches = core.run_checks(
abs_filename,
template,
rules,
regions)
matches = core.run_checks(abs_filename, template, rules, regions)
return matches

View File

@ -1,6 +1,6 @@
from .models import cloudwatch_backends
from ..core.models import base_decorator, deprecated_base_decorator
cloudwatch_backend = cloudwatch_backends['us-east-1']
cloudwatch_backend = cloudwatch_backends["us-east-1"]
mock_cloudwatch = base_decorator(cloudwatch_backends)
mock_cloudwatch_deprecated = deprecated_base_decorator(cloudwatch_backends)

View File

@ -1,20 +1,23 @@
import json
from moto.core.utils import iso_8601_datetime_with_milliseconds
from boto3 import Session
from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
import boto.ec2.cloudwatch
from moto.logs import logs_backends
from datetime import datetime, timedelta
from dateutil.tz import tzutc
from uuid import uuid4
from .utils import make_arn_for_dashboard
from dateutil import parser
DEFAULT_ACCOUNT_ID = 123456789012
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
_EMPTY_LIST = tuple()
class Dimension(object):
def __init__(self, name, value):
self.name = name
self.value = value
@ -49,10 +52,24 @@ def daterange(start, stop, step=timedelta(days=1), inclusive=False):
class FakeAlarm(BaseModel):
def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods,
period, threshold, statistic, description, dimensions, alarm_actions,
ok_actions, insufficient_data_actions, unit):
def __init__(
self,
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
actions_enabled,
):
self.name = name
self.namespace = namespace
self.metric_name = metric_name
@ -62,8 +79,10 @@ class FakeAlarm(BaseModel):
self.threshold = threshold
self.statistic = statistic
self.description = description
self.dimensions = [Dimension(dimension['name'], dimension[
'value']) for dimension in dimensions]
self.dimensions = [
Dimension(dimension["name"], dimension["value"]) for dimension in dimensions
]
self.actions_enabled = actions_enabled
self.alarm_actions = alarm_actions
self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions
@ -72,15 +91,21 @@ class FakeAlarm(BaseModel):
self.history = []
self.state_reason = ''
self.state_reason_data = '{}'
self.state_value = 'OK'
self.state_reason = ""
self.state_reason_data = "{}"
self.state_value = "OK"
self.state_updated_timestamp = datetime.utcnow()
def update_state(self, reason, reason_data, state_value):
# History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action
self.history.append(
('StateUpdate', self.state_reason, self.state_reason_data, self.state_value, self.state_updated_timestamp)
(
"StateUpdate",
self.state_reason,
self.state_reason_data,
self.state_value,
self.state_updated_timestamp,
)
)
self.state_reason = reason
@ -90,14 +115,14 @@ class FakeAlarm(BaseModel):
class MetricDatum(BaseModel):
def __init__(self, namespace, name, value, dimensions, timestamp):
self.namespace = namespace
self.name = name
self.value = value
self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc())
self.dimensions = [Dimension(dimension['Name'], dimension[
'Value']) for dimension in dimensions]
self.dimensions = [
Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions
]
class Dashboard(BaseModel):
@ -120,18 +145,18 @@ class Dashboard(BaseModel):
return len(self.body)
def __repr__(self):
return '<CloudWatchDashboard {0}>'.format(self.name)
return "<CloudWatchDashboard {0}>".format(self.name)
class Statistics:
def __init__(self, stats, dt):
self.timestamp = iso_8601_datetime_with_milliseconds(dt)
self.timestamp = iso_8601_datetime_without_milliseconds(dt)
self.values = []
self.stats = stats
@property
def sample_count(self):
if 'SampleCount' not in self.stats:
if "SampleCount" not in self.stats:
return None
return len(self.values)
@ -142,28 +167,28 @@ class Statistics:
@property
def sum(self):
if 'Sum' not in self.stats:
if "Sum" not in self.stats:
return None
return sum(self.values)
@property
def minimum(self):
if 'Minimum' not in self.stats:
if "Minimum" not in self.stats:
return None
return min(self.values)
@property
def maximum(self):
if 'Maximum' not in self.stats:
if "Maximum" not in self.stats:
return None
return max(self.values)
@property
def average(self):
if 'Average' not in self.stats:
if "Average" not in self.stats:
return None
# once moto requires Python 3.4+, we can switch to the statistics module
@ -171,18 +196,47 @@ class Statistics:
class CloudWatchBackend(BaseBackend):
def __init__(self):
self.alarms = {}
self.dashboards = {}
self.metric_data = []
self.paged_metric_data = {}
def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods,
period, threshold, statistic, description, dimensions,
alarm_actions, ok_actions, insufficient_data_actions, unit):
alarm = FakeAlarm(name, namespace, metric_name, comparison_operator, evaluation_periods, period,
threshold, statistic, description, dimensions, alarm_actions,
ok_actions, insufficient_data_actions, unit)
def put_metric_alarm(
self,
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
actions_enabled,
):
alarm = FakeAlarm(
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
actions_enabled,
)
self.alarms[name] = alarm
return alarm
@ -214,14 +268,12 @@ class CloudWatchBackend(BaseBackend):
]
def get_alarms_by_alarm_names(self, alarm_names):
return [
alarm
for alarm in self.alarms.values()
if alarm.name in alarm_names
]
return [alarm for alarm in self.alarms.values() if alarm.name in alarm_names]
def get_alarms_by_state_value(self, target_state):
return filter(lambda alarm: alarm.state_value == target_state, self.alarms.values())
return filter(
lambda alarm: alarm.state_value == target_state, self.alarms.values()
)
def delete_alarms(self, alarm_names):
for alarm_name in alarm_names:
@ -230,17 +282,30 @@ class CloudWatchBackend(BaseBackend):
def put_metric_data(self, namespace, metric_data):
for metric_member in metric_data:
# Preserve "datetime" for get_metric_statistics comparisons
timestamp = metric_member.get('Timestamp')
timestamp = metric_member.get("Timestamp")
if timestamp is not None and type(timestamp) != datetime:
timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ')
timestamp = timestamp.replace(tzinfo=tzutc())
self.metric_data.append(MetricDatum(
namespace, metric_member['MetricName'], float(metric_member.get('Value', 0)), metric_member.get('Dimensions.member', _EMPTY_LIST), timestamp))
timestamp = parser.parse(timestamp)
self.metric_data.append(
MetricDatum(
namespace,
metric_member["MetricName"],
float(metric_member.get("Value", 0)),
metric_member.get("Dimensions.member", _EMPTY_LIST),
timestamp,
)
)
def get_metric_statistics(self, namespace, metric_name, start_time, end_time, period, stats):
def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats
):
period_delta = timedelta(seconds=period)
filtered_data = [md for md in self.metric_data if
md.namespace == namespace and md.name == metric_name and start_time <= md.timestamp <= end_time]
filtered_data = [
md
for md in self.metric_data
if md.namespace == namespace
and md.name == metric_name
and start_time <= md.timestamp <= end_time
]
# sort from earliest to latest
filtered_data = sorted(filtered_data, key=lambda x: x.timestamp)
@ -249,9 +314,15 @@ class CloudWatchBackend(BaseBackend):
idx = 0
data = list()
for dt in daterange(filtered_data[0].timestamp, filtered_data[-1].timestamp + period_delta, period_delta):
for dt in daterange(
filtered_data[0].timestamp,
filtered_data[-1].timestamp + period_delta,
period_delta,
):
s = Statistics(stats, dt)
while idx < len(filtered_data) and filtered_data[idx].timestamp < (dt + period_delta):
while idx < len(filtered_data) and filtered_data[idx].timestamp < (
dt + period_delta
):
s.values.append(filtered_data[idx].value)
idx += 1
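
A hedged end-to-end sketch of this binning, assuming the mock_cloudwatch decorator: datapoints are grouped into one Statistics bucket per Period that actually contains data (namespace and values hypothetical):

from datetime import datetime, timedelta

import boto3
from moto import mock_cloudwatch

@mock_cloudwatch
def stats_demo():
    cw = boto3.client("cloudwatch", region_name="us-east-1")
    cw.put_metric_data(
        Namespace="Demo",
        MetricData=[{"MetricName": "Latency", "Value": 12.0}],
    )
    now = datetime.utcnow()
    resp = cw.get_metric_statistics(
        Namespace="Demo",
        MetricName="Latency",
        StartTime=now - timedelta(minutes=5),
        EndTime=now + timedelta(minutes=5),
        Period=60,
        Statistics=["Sum", "Average"],
    )
    print(resp["Datapoints"])  # one datapoint for the 60-second bucket holding the value

stats_demo()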
@ -268,7 +339,7 @@ class CloudWatchBackend(BaseBackend):
def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body)
def list_dashboards(self, prefix=''):
def list_dashboards(self, prefix=""):
for key, value in self.dashboards.items():
if key.startswith(prefix):
yield value
@ -280,7 +351,12 @@ class CloudWatchBackend(BaseBackend):
left_over = to_delete - all_dashboards
if len(left_over) > 0:
# Some dashboards are not found
return False, 'The specified dashboard does not exist. [{0}]'.format(', '.join(left_over))
return (
False,
"The specified dashboard does not exist. [{0}]".format(
", ".join(left_over)
),
)
for dashboard in to_delete:
del self.dashboards[dashboard]
@ -295,38 +371,77 @@ class CloudWatchBackend(BaseBackend):
if reason_data is not None:
json.loads(reason_data)
except ValueError:
raise RESTError('InvalidFormat', 'StateReasonData is invalid JSON')
raise RESTError("InvalidFormat", "StateReasonData is invalid JSON")
if alarm_name not in self.alarms:
raise RESTError('ResourceNotFound', 'Alarm {0} not found'.format(alarm_name), status=404)
raise RESTError(
"ResourceNotFound", "Alarm {0} not found".format(alarm_name), status=404
)
if state_value not in ('OK', 'ALARM', 'INSUFFICIENT_DATA'):
raise RESTError('InvalidParameterValue', 'StateValue is not one of OK | ALARM | INSUFFICIENT_DATA')
if state_value not in ("OK", "ALARM", "INSUFFICIENT_DATA"):
raise RESTError(
"InvalidParameterValue",
"StateValue is not one of OK | ALARM | INSUFFICIENT_DATA",
)
self.alarms[alarm_name].update_state(reason, reason_data, state_value)
def list_metrics(self, next_token, namespace, metric_name):
if next_token:
if next_token not in self.paged_metric_data:
raise RESTError(
"PaginationException", "Request parameter NextToken is invalid"
)
else:
metrics = self.paged_metric_data[next_token]
del self.paged_metric_data[next_token]  # can't reuse the same token twice
return self._get_paginated(metrics)
else:
metrics = self.get_filtered_metrics(metric_name, namespace)
return self._get_paginated(metrics)
def get_filtered_metrics(self, metric_name, namespace):
metrics = self.get_all_metrics()
if namespace:
metrics = [md for md in metrics if md.namespace == namespace]
if metric_name:
metrics = [md for md in metrics if md.name == metric_name]
return metrics
def _get_paginated(self, metrics):
if len(metrics) > 500:
next_token = str(uuid4())
self.paged_metric_data[next_token] = metrics[500:]
return next_token, metrics[0:500]
else:
return None, metrics
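
The token flow sketched against the mocked endpoint (assuming mock_cloudwatch): pages are capped at 500 metrics and, per the deletion above, each token is single-use:

import boto3
from moto import mock_cloudwatch

@mock_cloudwatch
def paginate_metrics():
    cw = boto3.client("cloudwatch", region_name="us-east-1")
    for i in range(600):  # more than one 500-item page
        cw.put_metric_data(
            Namespace="Demo",
            MetricData=[{"MetricName": "m{}".format(i), "Value": 1.0}],
        )
    first = cw.list_metrics()
    second = cw.list_metrics(NextToken=first["NextToken"])
    print(len(first["Metrics"]), len(second["Metrics"]))  # 500 100

paginate_metrics()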
class LogGroup(BaseModel):
def __init__(self, spec):
# required
self.name = spec['LogGroupName']
self.name = spec["LogGroupName"]
# optional
self.tags = spec.get('Tags', [])
self.tags = spec.get("Tags", [])
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
spec = {
'LogGroupName': properties['LogGroupName']
}
optional_properties = 'Tags'.split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]
return LogGroup(spec)
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
log_group_name = properties["LogGroupName"]
tags = properties.get("Tags", {})
return logs_backends[region_name].create_log_group(
log_group_name, tags, **properties
)
cloudwatch_backends = {}
for region in boto.ec2.cloudwatch.regions():
cloudwatch_backends[region.name] = CloudWatchBackend()
for region in Session().get_available_regions("cloudwatch"):
cloudwatch_backends[region] = CloudWatchBackend()
for region in Session().get_available_regions(
"cloudwatch", partition_name="aws-us-gov"
):
cloudwatch_backends[region] = CloudWatchBackend()
for region in Session().get_available_regions("cloudwatch", partition_name="aws-cn"):
cloudwatch_backends[region] = CloudWatchBackend()

View File

@ -6,7 +6,6 @@ from dateutil.parser import parse as dtparse
class CloudWatchResponse(BaseResponse):
@property
def cloudwatch_backend(self):
return cloudwatch_backends[self.region]
@ -17,45 +16,56 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def put_metric_alarm(self):
name = self._get_param('AlarmName')
namespace = self._get_param('Namespace')
metric_name = self._get_param('MetricName')
comparison_operator = self._get_param('ComparisonOperator')
evaluation_periods = self._get_param('EvaluationPeriods')
period = self._get_param('Period')
threshold = self._get_param('Threshold')
statistic = self._get_param('Statistic')
description = self._get_param('AlarmDescription')
dimensions = self._get_list_prefix('Dimensions.member')
alarm_actions = self._get_multi_param('AlarmActions.member')
ok_actions = self._get_multi_param('OKActions.member')
name = self._get_param("AlarmName")
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
comparison_operator = self._get_param("ComparisonOperator")
evaluation_periods = self._get_param("EvaluationPeriods")
period = self._get_param("Period")
threshold = self._get_param("Threshold")
statistic = self._get_param("Statistic")
description = self._get_param("AlarmDescription")
dimensions = self._get_list_prefix("Dimensions.member")
alarm_actions = self._get_multi_param("AlarmActions.member")
ok_actions = self._get_multi_param("OKActions.member")
actions_enabled = self._get_param("ActionsEnabled")
insufficient_data_actions = self._get_multi_param(
"InsufficientDataActions.member")
unit = self._get_param('Unit')
alarm = self.cloudwatch_backend.put_metric_alarm(name, namespace, metric_name,
comparison_operator,
evaluation_periods, period,
threshold, statistic,
description, dimensions,
alarm_actions, ok_actions,
insufficient_data_actions,
unit)
"InsufficientDataActions.member"
)
unit = self._get_param("Unit")
alarm = self.cloudwatch_backend.put_metric_alarm(
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
actions_enabled,
)
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm)
@amzn_request_id
def describe_alarms(self):
action_prefix = self._get_param('ActionPrefix')
alarm_name_prefix = self._get_param('AlarmNamePrefix')
alarm_names = self._get_multi_param('AlarmNames.member')
state_value = self._get_param('StateValue')
action_prefix = self._get_param("ActionPrefix")
alarm_name_prefix = self._get_param("AlarmNamePrefix")
alarm_names = self._get_multi_param("AlarmNames.member")
state_value = self._get_param("StateValue")
if action_prefix:
alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(
action_prefix)
alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(action_prefix)
elif alarm_name_prefix:
alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix(
alarm_name_prefix)
alarm_name_prefix
)
elif alarm_names:
alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names)
elif state_value:
@ -68,15 +78,15 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def delete_alarms(self):
alarm_names = self._get_multi_param('AlarmNames.member')
alarm_names = self._get_multi_param("AlarmNames.member")
self.cloudwatch_backend.delete_alarms(alarm_names)
template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE)
return template.render()
@amzn_request_id
def put_metric_data(self):
namespace = self._get_param('Namespace')
metric_data = self._get_multi_param('MetricData.member')
namespace = self._get_param("Namespace")
metric_data = self._get_multi_param("MetricData.member")
self.cloudwatch_backend.put_metric_data(namespace, metric_data)
template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
@ -84,43 +94,52 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def get_metric_statistics(self):
namespace = self._get_param('Namespace')
metric_name = self._get_param('MetricName')
start_time = dtparse(self._get_param('StartTime'))
end_time = dtparse(self._get_param('EndTime'))
period = int(self._get_param('Period'))
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
start_time = dtparse(self._get_param("StartTime"))
end_time = dtparse(self._get_param("EndTime"))
period = int(self._get_param("Period"))
statistics = self._get_multi_param("Statistics.member")
# Unsupported Parameters (To Be Implemented)
unit = self._get_param('Unit')
extended_statistics = self._get_param('ExtendedStatistics')
dimensions = self._get_param('Dimensions')
unit = self._get_param("Unit")
extended_statistics = self._get_param("ExtendedStatistics")
dimensions = self._get_param("Dimensions")
if unit or extended_statistics or dimensions:
raise NotImplemented()
raise NotImplementedError()
# TODO: this should instead throw InvalidParameterCombination
if not statistics:
raise NotImplemented("Must specify either Statistics or ExtendedStatistics")
raise NotImplementedError(
"Must specify either Statistics or ExtendedStatistics"
)
datapoints = self.cloudwatch_backend.get_metric_statistics(namespace, metric_name, start_time, end_time, period, statistics)
datapoints = self.cloudwatch_backend.get_metric_statistics(
namespace, metric_name, start_time, end_time, period, statistics
)
template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE)
return template.render(label=metric_name, datapoints=datapoints)
@amzn_request_id
def list_metrics(self):
metrics = self.cloudwatch_backend.get_all_metrics()
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
next_token = self._get_param("NextToken")
next_token, metrics = self.cloudwatch_backend.list_metrics(
next_token, namespace, metric_name
)
template = self.response_template(LIST_METRICS_TEMPLATE)
return template.render(metrics=metrics)
return template.render(metrics=metrics, next_token=next_token)
@amzn_request_id
def delete_dashboards(self):
dashboards = self._get_multi_param('DashboardNames.member')
dashboards = self._get_multi_param("DashboardNames.member")
if dashboards is None:
return self._error('InvalidParameterValue', 'Need at least 1 dashboard')
return self._error("InvalidParameterValue", "Need at least 1 dashboard")
status, error = self.cloudwatch_backend.delete_dashboards(dashboards)
if not status:
return self._error('ResourceNotFound', error)
return self._error("ResourceNotFound", error)
template = self.response_template(DELETE_DASHBOARD_TEMPLATE)
return template.render()
@ -143,18 +162,18 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def get_dashboard(self):
dashboard_name = self._get_param('DashboardName')
dashboard_name = self._get_param("DashboardName")
dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name)
if dashboard is None:
return self._error('ResourceNotFound', 'Dashboard does not exist')
return self._error("ResourceNotFound", "Dashboard does not exist")
template = self.response_template(GET_DASHBOARD_TEMPLATE)
return template.render(dashboard=dashboard)
@amzn_request_id
def list_dashboards(self):
prefix = self._get_param('DashboardNamePrefix', '')
prefix = self._get_param("DashboardNamePrefix", "")
dashboards = self.cloudwatch_backend.list_dashboards(prefix)
@ -163,13 +182,13 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def put_dashboard(self):
name = self._get_param('DashboardName')
body = self._get_param('DashboardBody')
name = self._get_param("DashboardName")
body = self._get_param("DashboardBody")
try:
json.loads(body)
except ValueError:
return self._error('InvalidParameterInput', 'Body is invalid JSON')
return self._error("InvalidParameterInput", "Body is invalid JSON")
self.cloudwatch_backend.put_dashboard(name, body)
@ -178,12 +197,14 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def set_alarm_state(self):
alarm_name = self._get_param('AlarmName')
reason = self._get_param('StateReason')
reason_data = self._get_param('StateReasonData')
state_value = self._get_param('StateValue')
alarm_name = self._get_param("AlarmName")
reason = self._get_param("StateReason")
reason_data = self._get_param("StateReasonData")
state_value = self._get_param("StateValue")
self.cloudwatch_backend.set_alarm_state(alarm_name, reason, reason_data, state_value)
self.cloudwatch_backend.set_alarm_state(
alarm_name, reason, reason_data, state_value
)
template = self.response_template(SET_ALARM_STATE_TEMPLATE)
return template.render()
@ -326,9 +347,11 @@ LIST_METRICS_TEMPLATE = """<ListMetricsResponse xmlns="http://monitoring.amazona
</member>
{% endfor %}
</Metrics>
{% if next_token is not none %}
<NextToken>
96e88479-4662-450b-8a13-239ded6ce9fe
{{ next_token }}
</NextToken>
{% endif %}
</ListMetricsResult>
</ListMetricsResponse>"""

View File

@ -1,9 +1,5 @@
from .responses import CloudWatchResponse
url_bases = [
"https?://monitoring.(.+).amazonaws.com",
]
url_bases = ["https?://monitoring.(.+).amazonaws.com"]
url_paths = {
'{0}/$': CloudWatchResponse.dispatch,
}
url_paths = {"{0}/$": CloudWatchResponse.dispatch}

View File

@ -0,0 +1,4 @@
from .models import codecommit_backends
from ..core.models import base_decorator
mock_codecommit = base_decorator(codecommit_backends)

View File

@ -0,0 +1,35 @@
from moto.core.exceptions import JsonRESTError
class RepositoryNameExistsException(JsonRESTError):
code = 400
def __init__(self, repository_name):
super(RepositoryNameExistsException, self).__init__(
"RepositoryNameExistsException",
"Repository named {0} already exists".format(repository_name),
)
class RepositoryDoesNotExistException(JsonRESTError):
code = 400
def __init__(self, repository_name):
super(RepositoryDoesNotExistException, self).__init__(
"RepositoryDoesNotExistException",
"{0} does not exist".format(repository_name),
)
class InvalidRepositoryNameException(JsonRESTError):
code = 400
def __init__(self):
super(InvalidRepositoryNameException, self).__init__(
"InvalidRepositoryNameException",
"The repository name is not valid. Repository names can be any valid "
"combination of letters, numbers, "
"periods, underscores, and dashes between 1 and 100 characters in "
"length. Names are case sensitive. "
"For more information, see Limits in the AWS CodeCommit User Guide. ",
)

69
moto/codecommit/models.py Normal file
View File

@ -0,0 +1,69 @@
from boto3 import Session
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from datetime import datetime
from moto.iam.models import ACCOUNT_ID
from .exceptions import RepositoryDoesNotExistException, RepositoryNameExistsException
import uuid
class CodeCommit(BaseModel):
def __init__(self, region, repository_description, repository_name):
current_date = iso_8601_datetime_with_milliseconds(datetime.utcnow())
self.repository_metadata = dict()
self.repository_metadata["repositoryName"] = repository_name
self.repository_metadata[
"cloneUrlSsh"
] = "ssh://git-codecommit.{0}.amazonaws.com/v1/repos/{1}".format(
region, repository_name
)
self.repository_metadata[
"cloneUrlHttp"
] = "https://git-codecommit.{0}.amazonaws.com/v1/repos/{1}".format(
region, repository_name
)
self.repository_metadata["creationDate"] = current_date
self.repository_metadata["lastModifiedDate"] = current_date
self.repository_metadata["repositoryDescription"] = repository_description
self.repository_metadata["repositoryId"] = str(uuid.uuid4())
self.repository_metadata["Arn"] = "arn:aws:codecommit:{0}:{1}:{2}".format(
region, ACCOUNT_ID, repository_name
)
self.repository_metadata["accountId"] = ACCOUNT_ID
class CodeCommitBackend(BaseBackend):
def __init__(self):
self.repositories = {}
def create_repository(self, region, repository_name, repository_description):
repository = self.repositories.get(repository_name)
if repository:
raise RepositoryNameExistsException(repository_name)
self.repositories[repository_name] = CodeCommit(
region, repository_description, repository_name
)
return self.repositories[repository_name].repository_metadata
def get_repository(self, repository_name):
repository = self.repositories.get(repository_name)
if not repository:
raise RepositoryDoesNotExistException(repository_name)
return repository.repository_metadata
def delete_repository(self, repository_name):
repository = self.repositories.get(repository_name)
if repository:
self.repositories.pop(repository_name)
return repository.repository_metadata.get("repositoryId")
return None
codecommit_backends = {}
for region in Session().get_available_regions("codecommit"):
codecommit_backends[region] = CodeCommitBackend()
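
A minimal sketch of the backend in use, assuming the mock_codecommit decorator wired up above is exported from moto (repository names are hypothetical):

import boto3
from moto import mock_codecommit

@mock_codecommit
def repo_demo():
    cc = boto3.client("codecommit", region_name="us-east-1")
    created = cc.create_repository(
        repositoryName="demo-repo",
        repositoryDescription="just a demo",
    )["repositoryMetadata"]
    print(created["cloneUrlSsh"])  # ssh://git-codecommit.us-east-1.amazonaws.com/v1/repos/demo-repo
    fetched = cc.get_repository(repositoryName="demo-repo")["repositoryMetadata"]
    assert fetched["repositoryId"] == created["repositoryId"]

repo_demo()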

View File

@ -0,0 +1,57 @@
import json
import re
from moto.core.responses import BaseResponse
from .models import codecommit_backends
from .exceptions import InvalidRepositoryNameException
def _is_repository_name_valid(repository_name):
name_regex = re.compile(r"[\w\.-]+")
result = name_regex.split(repository_name)
if len(result) > 0:
for match in result:
if len(match) > 0:
return False
return True
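
How the split-based check works: splitting on runs of allowed characters ([\w.-]) leaves only the disallowed characters as fragments, so any non-empty fragment marks the name invalid. A quick sketch:

import re

name_regex = re.compile(r"[\w\.-]+")

print(name_regex.split("my-repo_1.0"))  # ['', ''] -> every fragment empty, name is valid
print(name_regex.split("my repo!"))     # ['', ' ', '!'] -> non-empty fragments, name is invalid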
class CodeCommitResponse(BaseResponse):
@property
def codecommit_backend(self):
return codecommit_backends[self.region]
def create_repository(self):
if not _is_repository_name_valid(self._get_param("repositoryName")):
raise InvalidRepositoryNameException()
repository_metadata = self.codecommit_backend.create_repository(
self.region,
self._get_param("repositoryName"),
self._get_param("repositoryDescription"),
)
return json.dumps({"repositoryMetadata": repository_metadata})
def get_repository(self):
if not _is_repository_name_valid(self._get_param("repositoryName")):
raise InvalidRepositoryNameException()
repository_metadata = self.codecommit_backend.get_repository(
self._get_param("repositoryName")
)
return json.dumps({"repositoryMetadata": repository_metadata})
def delete_repository(self):
if not _is_repository_name_valid(self._get_param("repositoryName")):
raise InvalidRepositoryNameException()
repository_id = self.codecommit_backend.delete_repository(
self._get_param("repositoryName")
)
if repository_id:
return json.dumps({"repositoryId": repository_id})
return json.dumps({})

6
moto/codecommit/urls.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .responses import CodeCommitResponse
url_bases = ["https?://codecommit.(.+).amazonaws.com"]
url_paths = {"{0}/$": CodeCommitResponse.dispatch}

View File

@ -0,0 +1,4 @@
from .models import codepipeline_backends
from ..core.models import base_decorator
mock_codepipeline = base_decorator(codepipeline_backends)

View File

@ -0,0 +1,44 @@
from moto.core.exceptions import JsonRESTError
class InvalidStructureException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidStructureException, self).__init__(
"InvalidStructureException", message
)
class PipelineNotFoundException(JsonRESTError):
code = 400
def __init__(self, message):
super(PipelineNotFoundException, self).__init__(
"PipelineNotFoundException", message
)
class ResourceNotFoundException(JsonRESTError):
code = 400
def __init__(self, message):
super(ResourceNotFoundException, self).__init__(
"ResourceNotFoundException", message
)
class InvalidTagsException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidTagsException, self).__init__("InvalidTagsException", message)
class TooManyTagsException(JsonRESTError):
code = 400
def __init__(self, arn):
super(TooManyTagsException, self).__init__(
"TooManyTagsException", "Tag limit exceeded for resource [{}].".format(arn)
)

218
moto/codepipeline/models.py Normal file
View File

@ -0,0 +1,218 @@
import json
from datetime import datetime
from boto3 import Session
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.iam.exceptions import IAMNotFoundException
from moto.iam import iam_backends
from moto.codepipeline.exceptions import (
InvalidStructureException,
PipelineNotFoundException,
ResourceNotFoundException,
InvalidTagsException,
TooManyTagsException,
)
from moto.core import BaseBackend, BaseModel
from moto.iam.models import ACCOUNT_ID
class CodePipeline(BaseModel):
def __init__(self, region, pipeline):
# the version number for a new pipeline is always 1
pipeline["version"] = 1
self.pipeline = self.add_default_values(pipeline)
self.tags = {}
self._arn = "arn:aws:codepipeline:{0}:{1}:{2}".format(
region, ACCOUNT_ID, pipeline["name"]
)
self._created = datetime.utcnow()
self._updated = datetime.utcnow()
@property
def metadata(self):
return {
"pipelineArn": self._arn,
"created": iso_8601_datetime_with_milliseconds(self._created),
"updated": iso_8601_datetime_with_milliseconds(self._updated),
}
def add_default_values(self, pipeline):
for stage in pipeline["stages"]:
for action in stage["actions"]:
if "runOrder" not in action:
action["runOrder"] = 1
if "configuration" not in action:
action["configuration"] = {}
if "outputArtifacts" not in action:
action["outputArtifacts"] = []
if "inputArtifacts" not in action:
action["inputArtifacts"] = []
return pipeline
def validate_tags(self, tags):
for tag in tags:
if tag["key"].startswith("aws:"):
raise InvalidTagsException(
"Not allowed to modify system tags. "
"System tags start with 'aws:'. "
"msg=[Caller is an end user and not allowed to mutate system tags]"
)
if (len(self.tags) + len(tags)) > 50:
raise TooManyTagsException(self._arn)
class CodePipelineBackend(BaseBackend):
def __init__(self):
self.pipelines = {}
@property
def iam_backend(self):
return iam_backends["global"]
def create_pipeline(self, region, pipeline, tags):
if pipeline["name"] in self.pipelines:
raise InvalidStructureException(
"A pipeline with the name '{0}' already exists in account '{1}'".format(
pipeline["name"], ACCOUNT_ID
)
)
try:
role = self.iam_backend.get_role_by_arn(pipeline["roleArn"])
service_principal = json.loads(role.assume_role_policy_document)[
"Statement"
][0]["Principal"]["Service"]
if "codepipeline.amazonaws.com" not in service_principal:
raise IAMNotFoundException("")
except IAMNotFoundException:
raise InvalidStructureException(
"CodePipeline is not authorized to perform AssumeRole on role {}".format(
pipeline["roleArn"]
)
)
if len(pipeline["stages"]) < 2:
raise InvalidStructureException(
"Pipeline has only 1 stage(s). There should be a minimum of 2 stages in a pipeline"
)
self.pipelines[pipeline["name"]] = CodePipeline(region, pipeline)
if tags:
self.pipelines[pipeline["name"]].validate_tags(tags)
new_tags = {tag["key"]: tag["value"] for tag in tags}
self.pipelines[pipeline["name"]].tags.update(new_tags)
# tags may be None when the caller does not supply any
return pipeline, sorted(tags or [], key=lambda i: i["key"])
def get_pipeline(self, name):
codepipeline = self.pipelines.get(name)
if not codepipeline:
raise PipelineNotFoundException(
"Account '{0}' does not have a pipeline with name '{1}'".format(
ACCOUNT_ID, name
)
)
return codepipeline.pipeline, codepipeline.metadata
def update_pipeline(self, pipeline):
codepipeline = self.pipelines.get(pipeline["name"])
if not codepipeline:
raise ResourceNotFoundException(
"The account with id '{0}' does not include a pipeline with the name '{1}'".format(
ACCOUNT_ID, pipeline["name"]
)
)
# version number is auto incremented
pipeline["version"] = codepipeline.pipeline["version"] + 1
codepipeline._updated = datetime.utcnow()
codepipeline.pipeline = codepipeline.add_default_values(pipeline)
return codepipeline.pipeline
def list_pipelines(self):
pipelines = []
for name, codepipeline in self.pipelines.items():
pipelines.append(
{
"name": name,
"version": codepipeline.pipeline["version"],
"created": codepipeline.metadata["created"],
"updated": codepipeline.metadata["updated"],
}
)
return sorted(pipelines, key=lambda i: i["name"])
def delete_pipeline(self, name):
self.pipelines.pop(name, None)
def list_tags_for_resource(self, arn):
name = arn.split(":")[-1]
pipeline = self.pipelines.get(name)
if not pipeline:
raise ResourceNotFoundException(
"The account with id '{0}' does not include a pipeline with the name '{1}'".format(
ACCOUNT_ID, name
)
)
tags = [{"key": key, "value": value} for key, value in pipeline.tags.items()]
return sorted(tags, key=lambda i: i["key"])
def tag_resource(self, arn, tags):
name = arn.split(":")[-1]
pipeline = self.pipelines.get(name)
if not pipeline:
raise ResourceNotFoundException(
"The account with id '{0}' does not include a pipeline with the name '{1}'".format(
ACCOUNT_ID, name
)
)
pipeline.validate_tags(tags)
for tag in tags:
pipeline.tags.update({tag["key"]: tag["value"]})
def untag_resource(self, arn, tag_keys):
name = arn.split(":")[-1]
pipeline = self.pipelines.get(name)
if not pipeline:
raise ResourceNotFoundException(
"The account with id '{0}' does not include a pipeline with the name '{1}'".format(
ACCOUNT_ID, name
)
)
for key in tag_keys:
pipeline.tags.pop(key, None)
codepipeline_backends = {}
for region in Session().get_available_regions("codepipeline"):
codepipeline_backends[region] = CodePipelineBackend()
for region in Session().get_available_regions(
"codepipeline", partition_name="aws-us-gov"
):
codepipeline_backends[region] = CodePipelineBackend()
for region in Session().get_available_regions("codepipeline", partition_name="aws-cn"):
codepipeline_backends[region] = CodePipelineBackend()
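For context, here is a minimal usage sketch of the backend above. This example is editorial, not part of the commit: it assumes moto exports mock_codepipeline and mock_iam at the package root, and the role name, bucket, and stage names are illustrative.

```python
import json

import boto3
from moto import mock_codepipeline, mock_iam


@mock_iam
@mock_codepipeline
def test_create_and_list_pipelines():
    # The backend requires a role whose trust policy names codepipeline.amazonaws.com.
    iam = boto3.client("iam", region_name="us-east-1")
    role_arn = iam.create_role(
        RoleName="codepipeline-test-role",  # illustrative name
        AssumeRolePolicyDocument=json.dumps(
            {
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"Service": "codepipeline.amazonaws.com"},
                        "Action": "sts:AssumeRole",
                    }
                ]
            }
        ),
    )["Role"]["Arn"]

    client = boto3.client("codepipeline", region_name="us-east-1")
    client.create_pipeline(
        pipeline={
            "name": "test-pipeline",
            "roleArn": role_arn,
            "artifactStore": {"type": "S3", "location": "my-bucket"},
            # Fewer than two stages raises InvalidStructureException above.
            "stages": [
                {"name": "Source", "actions": []},
                {"name": "Build", "actions": []},
            ],
        },
        tags=[{"key": "Environment", "value": "test"}],
    )
    names = [p["name"] for p in client.list_pipelines()["pipelines"]]
    assert names == ["test-pipeline"]
```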


@@ -0,0 +1,62 @@
import json
from moto.core.responses import BaseResponse
from .models import codepipeline_backends
class CodePipelineResponse(BaseResponse):
@property
def codepipeline_backend(self):
return codepipeline_backends[self.region]
def create_pipeline(self):
pipeline, tags = self.codepipeline_backend.create_pipeline(
self.region, self._get_param("pipeline"), self._get_param("tags")
)
return json.dumps({"pipeline": pipeline, "tags": tags})
def get_pipeline(self):
pipeline, metadata = self.codepipeline_backend.get_pipeline(
self._get_param("name")
)
return json.dumps({"pipeline": pipeline, "metadata": metadata})
def update_pipeline(self):
pipeline = self.codepipeline_backend.update_pipeline(
self._get_param("pipeline")
)
return json.dumps({"pipeline": pipeline})
def list_pipelines(self):
pipelines = self.codepipeline_backend.list_pipelines()
return json.dumps({"pipelines": pipelines})
def delete_pipeline(self):
self.codepipeline_backend.delete_pipeline(self._get_param("name"))
return ""
def list_tags_for_resource(self):
tags = self.codepipeline_backend.list_tags_for_resource(
self._get_param("resourceArn")
)
return json.dumps({"tags": tags})
def tag_resource(self):
self.codepipeline_backend.tag_resource(
self._get_param("resourceArn"), self._get_param("tags")
)
return ""
def untag_resource(self):
self.codepipeline_backend.untag_resource(
self._get_param("resourceArn"), self._get_param("tagKeys")
)
return ""


@@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .responses import CodePipelineResponse
url_bases = ["https?://codepipeline.(.+).amazonaws.com"]
url_paths = {"{0}/$": CodePipelineResponse.dispatch}


@@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import cognitoidentity_backends
from ..core.models import base_decorator, deprecated_base_decorator
cognitoidentity_backend = cognitoidentity_backends['us-east-1']
cognitoidentity_backend = cognitoidentity_backends["us-east-1"]
mock_cognitoidentity = base_decorator(cognitoidentity_backends)
mock_cognitoidentity_deprecated = deprecated_base_decorator(cognitoidentity_backends)


@@ -0,0 +1,13 @@
from __future__ import unicode_literals
import json
from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest):
def __init__(self, message):
super(ResourceNotFoundError, self).__init__()
self.description = json.dumps(
{"message": message, "__type": "ResourceNotFoundException"}
)


@@ -3,32 +3,34 @@ from __future__ import unicode_literals
import datetime
import json
import boto.cognito.identity
from boto3 import Session
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from .exceptions import ResourceNotFoundError
from .utils import get_random_identity_id
class CognitoIdentity(BaseModel):
def __init__(self, region, identity_pool_name, **kwargs):
self.identity_pool_name = identity_pool_name
self.allow_unauthenticated_identities = kwargs.get('allow_unauthenticated_identities', '')
self.supported_login_providers = kwargs.get('supported_login_providers', {})
self.developer_provider_name = kwargs.get('developer_provider_name', '')
self.open_id_connect_provider_arns = kwargs.get('open_id_connect_provider_arns', [])
self.cognito_identity_providers = kwargs.get('cognito_identity_providers', [])
self.saml_provider_arns = kwargs.get('saml_provider_arns', [])
self.allow_unauthenticated_identities = kwargs.get(
"allow_unauthenticated_identities", ""
)
self.supported_login_providers = kwargs.get("supported_login_providers", {})
self.developer_provider_name = kwargs.get("developer_provider_name", "")
self.open_id_connect_provider_arns = kwargs.get(
"open_id_connect_provider_arns", []
)
self.cognito_identity_providers = kwargs.get("cognito_identity_providers", [])
self.saml_provider_arns = kwargs.get("saml_provider_arns", [])
self.identity_pool_id = get_random_identity_id(region)
self.creation_time = datetime.datetime.utcnow()
class CognitoIdentityBackend(BaseBackend):
def __init__(self, region):
super(CognitoIdentityBackend, self).__init__()
self.region = region
@@ -39,34 +41,67 @@ class CognitoIdentityBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region)
def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities,
supported_login_providers, developer_provider_name, open_id_connect_provider_arns,
cognito_identity_providers, saml_provider_arns):
def describe_identity_pool(self, identity_pool_id):
identity_pool = self.identity_pools.get(identity_pool_id, None)
new_identity = CognitoIdentity(self.region, identity_pool_name,
if not identity_pool:
raise ResourceNotFoundError(identity_pool_id)
response = json.dumps(
{
"AllowUnauthenticatedIdentities": identity_pool.allow_unauthenticated_identities,
"CognitoIdentityProviders": identity_pool.cognito_identity_providers,
"DeveloperProviderName": identity_pool.developer_provider_name,
"IdentityPoolId": identity_pool.identity_pool_id,
"IdentityPoolName": identity_pool.identity_pool_name,
"IdentityPoolTags": {},
"OpenIdConnectProviderARNs": identity_pool.open_id_connect_provider_arns,
"SamlProviderARNs": identity_pool.saml_provider_arns,
"SupportedLoginProviders": identity_pool.supported_login_providers,
}
)
return response
def create_identity_pool(
self,
identity_pool_name,
allow_unauthenticated_identities,
supported_login_providers,
developer_provider_name,
open_id_connect_provider_arns,
cognito_identity_providers,
saml_provider_arns,
):
new_identity = CognitoIdentity(
self.region,
identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers,
developer_provider_name=developer_provider_name,
open_id_connect_provider_arns=open_id_connect_provider_arns,
cognito_identity_providers=cognito_identity_providers,
saml_provider_arns=saml_provider_arns)
saml_provider_arns=saml_provider_arns,
)
self.identity_pools[new_identity.identity_pool_id] = new_identity
response = json.dumps({
'IdentityPoolId': new_identity.identity_pool_id,
'IdentityPoolName': new_identity.identity_pool_name,
'AllowUnauthenticatedIdentities': new_identity.allow_unauthenticated_identities,
'SupportedLoginProviders': new_identity.supported_login_providers,
'DeveloperProviderName': new_identity.developer_provider_name,
'OpenIdConnectProviderARNs': new_identity.open_id_connect_provider_arns,
'CognitoIdentityProviders': new_identity.cognito_identity_providers,
'SamlProviderARNs': new_identity.saml_provider_arns
})
response = json.dumps(
{
"IdentityPoolId": new_identity.identity_pool_id,
"IdentityPoolName": new_identity.identity_pool_name,
"AllowUnauthenticatedIdentities": new_identity.allow_unauthenticated_identities,
"SupportedLoginProviders": new_identity.supported_login_providers,
"DeveloperProviderName": new_identity.developer_provider_name,
"OpenIdConnectProviderARNs": new_identity.open_id_connect_provider_arns,
"CognitoIdentityProviders": new_identity.cognito_identity_providers,
"SamlProviderARNs": new_identity.saml_provider_arns,
}
)
return response
def get_id(self):
identity_id = {'IdentityId': get_random_identity_id(self.region)}
identity_id = {"IdentityId": get_random_identity_id(self.region)}
return json.dumps(identity_id)
def get_credentials_for_identity(self, identity_id):
@@ -76,35 +111,38 @@ class CognitoIdentityBackend(BaseBackend):
expiration_str = str(iso_8601_datetime_with_milliseconds(expiration))
response = json.dumps(
{
"Credentials":
{
"Credentials": {
"AccessKeyId": "TESTACCESSKEY12345",
"Expiration": expiration_str,
"SecretKey": "ABCSECRETKEY",
"SessionToken": "ABC12345"
"SessionToken": "ABC12345",
},
"IdentityId": identity_id
})
"IdentityId": identity_id,
}
)
return response
def get_open_id_token_for_developer_identity(self, identity_id):
response = json.dumps(
{
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region)
})
{"IdentityId": identity_id, "Token": get_random_identity_id(self.region)}
)
return response
def get_open_id_token(self, identity_id):
response = json.dumps(
{
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region)
}
{"IdentityId": identity_id, "Token": get_random_identity_id(self.region)}
)
return response
cognitoidentity_backends = {}
for region in boto.cognito.identity.regions():
cognitoidentity_backends[region.name] = CognitoIdentityBackend(region.name)
for region in Session().get_available_regions("cognito-identity"):
cognitoidentity_backends[region] = CognitoIdentityBackend(region)
for region in Session().get_available_regions(
"cognito-identity", partition_name="aws-us-gov"
):
cognitoidentity_backends[region] = CognitoIdentityBackend(region)
for region in Session().get_available_regions(
"cognito-identity", partition_name="aws-cn"
):
cognitoidentity_backends[region] = CognitoIdentityBackend(region)
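An illustrative sketch (editorial, not part of the commit) of the new describe_identity_pool path above, assuming moto exports mock_cognitoidentity at the package root; the pool name is arbitrary:

```python
import boto3
from moto import mock_cognitoidentity


@mock_cognitoidentity
def test_describe_identity_pool():
    client = boto3.client("cognito-identity", region_name="us-west-2")
    pool_id = client.create_identity_pool(
        IdentityPoolName="TestPool", AllowUnauthenticatedIdentities=True
    )["IdentityPoolId"]

    # Unknown pool ids raise ResourceNotFoundError in the backend above.
    described = client.describe_identity_pool(IdentityPoolId=pool_id)
    assert described["IdentityPoolName"] == "TestPool"
```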


@@ -1,21 +1,22 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from .models import cognitoidentity_backends
from .utils import get_random_identity_id
class CognitoIdentityResponse(BaseResponse):
def create_identity_pool(self):
identity_pool_name = self._get_param('IdentityPoolName')
allow_unauthenticated_identities = self._get_param('AllowUnauthenticatedIdentities')
supported_login_providers = self._get_param('SupportedLoginProviders')
developer_provider_name = self._get_param('DeveloperProviderName')
open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs')
cognito_identity_providers = self._get_param('CognitoIdentityProviders')
saml_provider_arns = self._get_param('SamlProviderARNs')
identity_pool_name = self._get_param("IdentityPoolName")
allow_unauthenticated_identities = self._get_param(
"AllowUnauthenticatedIdentities"
)
supported_login_providers = self._get_param("SupportedLoginProviders")
developer_provider_name = self._get_param("DeveloperProviderName")
open_id_connect_provider_arns = self._get_param("OpenIdConnectProviderARNs")
cognito_identity_providers = self._get_param("CognitoIdentityProviders")
saml_provider_arns = self._get_param("SamlProviderARNs")
return cognitoidentity_backends[self.region].create_identity_pool(
identity_pool_name=identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities,
@@ -23,17 +24,27 @@ class CognitoIdentityResponse(BaseResponse):
developer_provider_name=developer_provider_name,
open_id_connect_provider_arns=open_id_connect_provider_arns,
cognito_identity_providers=cognito_identity_providers,
saml_provider_arns=saml_provider_arns)
saml_provider_arns=saml_provider_arns,
)
def get_id(self):
return cognitoidentity_backends[self.region].get_id()
def describe_identity_pool(self):
return cognitoidentity_backends[self.region].describe_identity_pool(
self._get_param("IdentityPoolId")
)
def get_credentials_for_identity(self):
return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId'))
return cognitoidentity_backends[self.region].get_credentials_for_identity(
self._get_param("IdentityId")
)
def get_open_id_token_for_developer_identity(self):
return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity(
self._get_param('IdentityId') or get_random_identity_id(self.region)
return cognitoidentity_backends[
self.region
].get_open_id_token_for_developer_identity(
self._get_param("IdentityId") or get_random_identity_id(self.region)
)
def get_open_id_token(self):


@@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import CognitoIdentityResponse
url_bases = [
"https?://cognito-identity.(.+).amazonaws.com",
]
url_bases = ["https?://cognito-identity.(.+).amazonaws.com"]
url_paths = {
'{0}/$': CognitoIdentityResponse.dispatch,
}
url_paths = {"{0}/$": CognitoIdentityResponse.dispatch}


@@ -5,40 +5,40 @@ from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest):
def __init__(self, message):
super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'ResourceNotFoundException',
})
self.description = json.dumps(
{"message": message, "__type": "ResourceNotFoundException"}
)
class UserNotFoundError(BadRequest):
def __init__(self, message):
super(UserNotFoundError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'UserNotFoundException',
})
self.description = json.dumps(
{"message": message, "__type": "UserNotFoundException"}
)
class UsernameExistsException(BadRequest):
def __init__(self, message):
super(UsernameExistsException, self).__init__()
self.description = json.dumps(
{"message": message, "__type": "UsernameExistsException"}
)
class GroupExistsException(BadRequest):
def __init__(self, message):
super(GroupExistsException, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'GroupExistsException',
})
self.description = json.dumps(
{"message": message, "__type": "GroupExistsException"}
)
class NotAuthorizedError(BadRequest):
def __init__(self, message):
super(NotAuthorizedError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'NotAuthorizedException',
})
self.description = json.dumps(
{"message": message, "__type": "NotAuthorizedException"}
)


@@ -9,12 +9,19 @@ import os
import time
import uuid
import boto.cognito.identity
from boto3 import Session
from jose import jws
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from .exceptions import GroupExistsException, NotAuthorizedError, ResourceNotFoundError, UserNotFoundError
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from .exceptions import (
GroupExistsException,
NotAuthorizedError,
ResourceNotFoundError,
UserNotFoundError,
UsernameExistsException,
)
UserStatus = {
"FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD",
@@ -44,22 +51,28 @@ def paginate(limit, start_arg="next_token", limit_arg="max_results"):
def outer_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
start = int(default_start if kwargs.get(start_arg) is None else kwargs[start_arg])
start = int(
default_start if kwargs.get(start_arg) is None else kwargs[start_arg]
)
lim = int(limit if kwargs.get(limit_arg) is None else kwargs[limit_arg])
stop = start + lim
result = func(*args, **kwargs)
limited_results = list(itertools.islice(result, start, stop))
next_token = stop if stop < len(result) else None
return limited_results, next_token
return wrapper
return outer_wrapper
class CognitoIdpUserPool(BaseModel):
def __init__(self, region, name, extended_config):
self.region = region
self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
self.arn = "arn:aws:cognito-idp:{}:{}:userpool/{}".format(
self.region, DEFAULT_ACCOUNT_ID, self.id
)
self.name = name
self.status = None
self.extended_config = extended_config or {}
@@ -74,12 +87,15 @@ class CognitoIdpUserPool(BaseModel):
self.access_tokens = {}
self.id_tokens = {}
with open(os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")) as f:
with open(
os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")
) as f:
self.json_web_key = json.loads(f.read())
def _base_json(self):
return {
"Id": self.id,
"Arn": self.arn,
"Name": self.name,
"Status": self.status,
"CreationDate": time.mktime(self.creation_date.timetuple()),
@@ -91,26 +107,35 @@ class CognitoIdpUserPool(BaseModel):
if extended:
user_pool_json.update(self.extended_config)
else:
user_pool_json["LambdaConfig"] = self.extended_config.get("LambdaConfig") or {}
user_pool_json["LambdaConfig"] = (
self.extended_config.get("LambdaConfig") or {}
)
return user_pool_json
def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data={}):
def create_jwt(
self, client_id, username, token_use, expires_in=60 * 60, extra_data={}
):
now = int(time.time())
payload = {
"iss": "https://cognito-idp.{}.amazonaws.com/{}".format(self.region, self.id),
"iss": "https://cognito-idp.{}.amazonaws.com/{}".format(
self.region, self.id
),
"sub": self.users[username].id,
"aud": client_id,
"token_use": "id",
"token_use": token_use,
"auth_time": now,
"exp": now + expires_in,
}
payload.update(extra_data)
return jws.sign(payload, self.json_web_key, algorithm='RS256'), expires_in
return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in
def create_id_token(self, client_id, username):
id_token, expires_in = self.create_jwt(client_id, username)
extra_data = self.get_user_extra_data_by_client_id(client_id, username)
id_token, expires_in = self.create_jwt(
client_id, username, "id", extra_data=extra_data
)
self.id_tokens[id_token] = (client_id, username)
return id_token, expires_in
@@ -120,11 +145,7 @@ class CognitoIdpUserPool(BaseModel):
return refresh_token
def create_access_token(self, client_id, username):
extra_data = self.get_user_extra_data_by_client_id(
client_id, username
)
access_token, expires_in = self.create_jwt(client_id, username,
extra_data=extra_data)
access_token, expires_in = self.create_jwt(client_id, username, "access")
self.access_tokens[access_token] = (client_id, username)
return access_token, expires_in
@@ -142,29 +163,27 @@ class CognitoIdpUserPool(BaseModel):
current_client = self.clients.get(client_id, None)
if current_client:
for readable_field in current_client.get_readable_fields():
attribute = list(filter(
lambda f: f['Name'] == readable_field,
self.users.get(username).attributes
))
attribute = list(
filter(
lambda f: f["Name"] == readable_field,
self.users.get(username).attributes,
)
)
if len(attribute) > 0:
extra_data.update({
attribute[0]['Name']: attribute[0]['Value']
})
extra_data.update({attribute[0]["Name"]: attribute[0]["Value"]})
return extra_data
class CognitoIdpUserPoolDomain(BaseModel):
def __init__(self, user_pool_id, domain, custom_domain_config=None):
self.user_pool_id = user_pool_id
self.domain = domain
self.custom_domain_config = custom_domain_config or {}
def _distribution_name(self):
if self.custom_domain_config and \
'CertificateArn' in self.custom_domain_config:
if self.custom_domain_config and "CertificateArn" in self.custom_domain_config:
hash = hashlib.md5(
self.custom_domain_config['CertificateArn'].encode('utf-8')
self.custom_domain_config["CertificateArn"].encode("utf-8")
).hexdigest()
return "{hash}.cloudfront.net".format(hash=hash[:16])
return None
@@ -182,14 +201,11 @@ class CognitoIdpUserPoolDomain(BaseModel):
"Version": None,
}
elif distribution:
return {
"CloudFrontDomain": distribution,
}
return {"CloudFrontDomain": distribution}
return None
class CognitoIdpUserPoolClient(BaseModel):
def __init__(self, user_pool_id, extended_config):
self.user_pool_id = user_pool_id
self.id = str(uuid.uuid4())
@@ -211,11 +227,10 @@ class CognitoIdpUserPoolClient(BaseModel):
return user_pool_client_json
def get_readable_fields(self):
return self.extended_config.get('ReadAttributes', [])
return self.extended_config.get("ReadAttributes", [])
class CognitoIdpIdentityProvider(BaseModel):
def __init__(self, name, extended_config):
self.name = name
self.extended_config = extended_config or {}
@@ -239,7 +254,6 @@ class CognitoIdpIdentityProvider(BaseModel):
class CognitoIdpGroup(BaseModel):
def __init__(self, user_pool_id, group_name, description, role_arn, precedence):
self.user_pool_id = user_pool_id
self.group_name = group_name
@@ -266,7 +280,6 @@ class CognitoIdpGroup(BaseModel):
class CognitoIdpUser(BaseModel):
def __init__(self, user_pool_id, username, password, status, attributes):
self.id = str(uuid.uuid4())
self.user_pool_id = user_pool_id
@@ -299,19 +312,18 @@ class CognitoIdpUser(BaseModel):
{
"Enabled": self.enabled,
attributes_key: self.attributes,
"MFAOptions": []
"MFAOptions": [],
}
)
return user_json
def update_attributes(self, new_attributes):
def flatten_attrs(attrs):
return {attr['Name']: attr['Value'] for attr in attrs}
return {attr["Name"]: attr["Value"] for attr in attrs}
def expand_attrs(attrs):
return [{'Name': k, 'Value': v} for k, v in attrs.items()]
return [{"Name": k, "Value": v} for k, v in attrs.items()]
flat_attributes = flatten_attrs(self.attributes)
flat_attributes.update(flatten_attrs(new_attributes))
@@ -319,7 +331,6 @@ class CognitoIdpUser(BaseModel):
class CognitoIdpBackend(BaseBackend):
def __init__(self, region):
super(CognitoIdpBackend, self).__init__()
self.region = region
@@ -495,7 +506,9 @@ class CognitoIdpBackend(BaseBackend):
if not user_pool:
raise ResourceNotFoundError(user_pool_id)
group = CognitoIdpGroup(user_pool_id, group_name, description, role_arn, precedence)
group = CognitoIdpGroup(
user_pool_id, group_name, description, role_arn, precedence
)
if group.group_name in user_pool.groups:
raise GroupExistsException("A group with the name already exists")
user_pool.groups[group.group_name] = group
@@ -556,12 +569,26 @@ class CognitoIdpBackend(BaseBackend):
user.groups.discard(group)
# User
def admin_create_user(self, user_pool_id, username, temporary_password, attributes):
def admin_create_user(
self, user_pool_id, username, message_action, temporary_password, attributes
):
user_pool = self.user_pools.get(user_pool_id)
if not user_pool:
raise ResourceNotFoundError(user_pool_id)
user = CognitoIdpUser(user_pool_id, username, temporary_password, UserStatus["FORCE_CHANGE_PASSWORD"], attributes)
if message_action and message_action == "RESEND":
if username not in user_pool.users:
raise UserNotFoundError(username)
elif username in user_pool.users:
raise UsernameExistsException(username)
user = CognitoIdpUser(
user_pool_id,
username,
temporary_password,
UserStatus["FORCE_CHANGE_PASSWORD"],
attributes,
)
user_pool.users[user.username] = user
return user
@@ -607,7 +634,9 @@ class CognitoIdpBackend(BaseBackend):
def _log_user_in(self, user_pool, client, username):
refresh_token = user_pool.create_refresh_token(client.id, username)
access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token)
access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(
refresh_token
)
return {
"AuthenticationResult": {
@@ -650,7 +679,11 @@ class CognitoIdpBackend(BaseBackend):
return self._log_user_in(user_pool, client, username)
elif auth_flow == "REFRESH_TOKEN":
refresh_token = auth_parameters.get("REFRESH_TOKEN")
id_token, access_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token)
(
id_token,
access_token,
expires_in,
) = user_pool.create_tokens_from_refresh_token(refresh_token)
return {
"AuthenticationResult": {
@@ -662,7 +695,9 @@ class CognitoIdpBackend(BaseBackend):
else:
return {}
def respond_to_auth_challenge(self, session, client_id, challenge_name, challenge_responses):
def respond_to_auth_challenge(
self, session, client_id, challenge_name, challenge_responses
):
user_pool = self.sessions.get(session)
if not user_pool:
raise ResourceNotFoundError(session)
@@ -726,8 +761,14 @@ class CognitoIdpBackend(BaseBackend):
cognitoidp_backends = {}
for region in boto.cognito.identity.regions():
cognitoidp_backends[region.name] = CognitoIdpBackend(region.name)
for region in Session().get_available_regions("cognito-idp"):
cognitoidp_backends[region] = CognitoIdpBackend(region)
for region in Session().get_available_regions(
"cognito-idp", partition_name="aws-us-gov"
):
cognitoidp_backends[region] = CognitoIdpBackend(region)
for region in Session().get_available_regions("cognito-idp", partition_name="aws-cn"):
cognitoidp_backends[region] = CognitoIdpBackend(region)
# Hack to help moto-server process requests on localhost, where the region isn't
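The new MessageAction handling in admin_create_user can be sketched as follows. This is an editorial example, not part of the commit; it assumes pytest and the mock_cognitoidp decorator, and the pool and user names are illustrative.

```python
import boto3
import pytest
from moto import mock_cognitoidp


@mock_cognitoidp
def test_admin_create_user_message_action():
    client = boto3.client("cognito-idp", region_name="us-west-2")
    pool_id = client.create_user_pool(PoolName="test-pool")["UserPool"]["Id"]

    # RESEND for a user that was never created maps to UserNotFoundError above.
    with pytest.raises(client.exceptions.UserNotFoundException):
        client.admin_create_user(
            UserPoolId=pool_id, Username="ghost", MessageAction="RESEND"
        )

    # Creating the same username twice maps to UsernameExistsException above.
    client.admin_create_user(UserPoolId=pool_id, Username="jdoe")
    with pytest.raises(client.exceptions.UsernameExistsException):
        client.admin_create_user(UserPoolId=pool_id, Username="jdoe")
```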


@@ -8,7 +8,6 @@ from .models import cognitoidp_backends, find_region_by_value
class CognitoIdpResponse(BaseResponse):
@property
def parameters(self):
return json.loads(self.body)
@@ -16,10 +15,10 @@ class CognitoIdpResponse(BaseResponse):
# User pool
def create_user_pool(self):
name = self.parameters.pop("PoolName")
user_pool = cognitoidp_backends[self.region].create_user_pool(name, self.parameters)
return json.dumps({
"UserPool": user_pool.to_json(extended=True)
})
user_pool = cognitoidp_backends[self.region].create_user_pool(
name, self.parameters
)
return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def list_user_pools(self):
max_results = self._get_param("MaxResults")
@@ -27,9 +26,7 @@ class CognitoIdpResponse(BaseResponse):
user_pools, next_token = cognitoidp_backends[self.region].list_user_pools(
max_results=max_results, next_token=next_token
)
response = {
"UserPools": [user_pool.to_json() for user_pool in user_pools],
}
response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]}
if next_token:
response["NextToken"] = str(next_token)
return json.dumps(response)
@@ -37,9 +34,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool(self):
user_pool_id = self._get_param("UserPoolId")
user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id)
return json.dumps({
"UserPool": user_pool.to_json(extended=True)
})
return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def delete_user_pool(self):
user_pool_id = self._get_param("UserPoolId")
@@ -61,14 +56,14 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_domain(self):
domain = self._get_param("Domain")
user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(domain)
user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(
domain
)
domain_description = {}
if user_pool_domain:
domain_description = user_pool_domain.to_json()
return json.dumps({
"DomainDescription": domain_description
})
return json.dumps({"DomainDescription": domain_description})
def delete_user_pool_domain(self):
domain = self._get_param("Domain")
@@ -89,19 +84,24 @@ class CognitoIdpResponse(BaseResponse):
# User pool client
def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId")
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(user_pool_id, self.parameters)
return json.dumps({
"UserPoolClient": user_pool_client.to_json(extended=True)
})
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(
user_pool_id, self.parameters
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def list_user_pool_clients(self):
user_pool_id = self._get_param("UserPoolId")
max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken", "0")
user_pool_clients, next_token = cognitoidp_backends[self.region].list_user_pool_clients(user_pool_id,
max_results=max_results, next_token=next_token)
user_pool_clients, next_token = cognitoidp_backends[
self.region
].list_user_pool_clients(
user_pool_id, max_results=max_results, next_token=next_token
)
response = {
"UserPoolClients": [user_pool_client.to_json() for user_pool_client in user_pool_clients]
"UserPoolClients": [
user_pool_client.to_json() for user_pool_client in user_pool_clients
]
}
if next_token:
response["NextToken"] = str(next_token)
@@ -110,43 +110,51 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId")
user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(user_pool_id, client_id)
return json.dumps({
"UserPoolClient": user_pool_client.to_json(extended=True)
})
user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(
user_pool_id, client_id
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def update_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId")
client_id = self.parameters.pop("ClientId")
user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(user_pool_id, client_id, self.parameters)
return json.dumps({
"UserPoolClient": user_pool_client.to_json(extended=True)
})
user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(
user_pool_id, client_id, self.parameters
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def delete_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId")
cognitoidp_backends[self.region].delete_user_pool_client(user_pool_id, client_id)
cognitoidp_backends[self.region].delete_user_pool_client(
user_pool_id, client_id
)
return ""
# Identity provider
def create_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self.parameters.pop("ProviderName")
identity_provider = cognitoidp_backends[self.region].create_identity_provider(user_pool_id, name, self.parameters)
return json.dumps({
"IdentityProvider": identity_provider.to_json(extended=True)
})
identity_provider = cognitoidp_backends[self.region].create_identity_provider(
user_pool_id, name, self.parameters
)
return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def list_identity_providers(self):
user_pool_id = self._get_param("UserPoolId")
max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken", "0")
identity_providers, next_token = cognitoidp_backends[self.region].list_identity_providers(
identity_providers, next_token = cognitoidp_backends[
self.region
].list_identity_providers(
user_pool_id, max_results=max_results, next_token=next_token
)
response = {
"Providers": [identity_provider.to_json() for identity_provider in identity_providers]
"Providers": [
identity_provider.to_json() for identity_provider in identity_providers
]
}
if next_token:
response["NextToken"] = str(next_token)
@@ -155,18 +163,22 @@ class CognitoIdpResponse(BaseResponse):
def describe_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].describe_identity_provider(user_pool_id, name)
return json.dumps({
"IdentityProvider": identity_provider.to_json(extended=True)
})
identity_provider = cognitoidp_backends[self.region].describe_identity_provider(
user_pool_id, name
)
return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def update_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].update_identity_provider(user_pool_id, name, self.parameters)
return json.dumps({
"IdentityProvider": identity_provider.to_json(extended=True)
})
identity_provider = cognitoidp_backends[self.region].update_identity_provider(
user_pool_id, name, self.parameters
)
return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def delete_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
@@ -183,31 +195,21 @@ class CognitoIdpResponse(BaseResponse):
precedence = self._get_param("Precedence")
group = cognitoidp_backends[self.region].create_group(
user_pool_id,
group_name,
description,
role_arn,
precedence,
user_pool_id, group_name, description, role_arn, precedence
)
return json.dumps({
"Group": group.to_json(),
})
return json.dumps({"Group": group.to_json()})
def get_group(self):
group_name = self._get_param("GroupName")
user_pool_id = self._get_param("UserPoolId")
group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name)
return json.dumps({
"Group": group.to_json(),
})
return json.dumps({"Group": group.to_json()})
def list_groups(self):
user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].list_groups(user_pool_id)
return json.dumps({
"Groups": [group.to_json() for group in groups],
})
return json.dumps({"Groups": [group.to_json() for group in groups]})
def delete_group(self):
group_name = self._get_param("GroupName")
@@ -221,9 +223,7 @@ class CognitoIdpResponse(BaseResponse):
group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_add_user_to_group(
user_pool_id,
group_name,
username,
user_pool_id, group_name, username
)
return ""
@@ -231,18 +231,18 @@ class CognitoIdpResponse(BaseResponse):
def list_users_in_group(self):
user_pool_id = self._get_param("UserPoolId")
group_name = self._get_param("GroupName")
users = cognitoidp_backends[self.region].list_users_in_group(user_pool_id, group_name)
return json.dumps({
"Users": [user.to_json(extended=True) for user in users],
})
users = cognitoidp_backends[self.region].list_users_in_group(
user_pool_id, group_name
)
return json.dumps({"Users": [user.to_json(extended=True) for user in users]})
def admin_list_groups_for_user(self):
username = self._get_param("Username")
user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].admin_list_groups_for_user(user_pool_id, username)
return json.dumps({
"Groups": [group.to_json() for group in groups],
})
groups = cognitoidp_backends[self.region].admin_list_groups_for_user(
user_pool_id, username
)
return json.dumps({"Groups": [group.to_json() for group in groups]})
def admin_remove_user_from_group(self):
user_pool_id = self._get_param("UserPoolId")
@@ -250,9 +250,7 @@ class CognitoIdpResponse(BaseResponse):
group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_remove_user_from_group(
user_pool_id,
group_name,
username,
user_pool_id, group_name, username
)
return ""
@@ -261,33 +259,40 @@ class CognitoIdpResponse(BaseResponse):
def admin_create_user(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
message_action = self._get_param("MessageAction")
temporary_password = self._get_param("TemporaryPassword")
user = cognitoidp_backends[self.region].admin_create_user(
user_pool_id,
username,
message_action,
temporary_password,
self._get_param("UserAttributes", [])
self._get_param("UserAttributes", []),
)
return json.dumps({
"User": user.to_json(extended=True)
})
return json.dumps({"User": user.to_json(extended=True)})
def admin_get_user(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username)
return json.dumps(
user.to_json(extended=True, attributes_key="UserAttributes")
)
return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes"))
def list_users(self):
user_pool_id = self._get_param("UserPoolId")
limit = self._get_param("Limit")
token = self._get_param("PaginationToken")
users, token = cognitoidp_backends[self.region].list_users(user_pool_id,
limit=limit,
pagination_token=token)
filt = self._get_param("Filter")
users, token = cognitoidp_backends[self.region].list_users(
user_pool_id, limit=limit, pagination_token=token
)
if filt:
name, value = filt.replace('"', "").split("=")
users = [
user
for user in users
for attribute in user.attributes
if attribute["Name"] == name and attribute["Value"] == value
]
response = {"Users": [user.to_json(extended=True) for user in users]}
if token:
response["PaginationToken"] = str(token)
@@ -318,10 +323,7 @@ class CognitoIdpResponse(BaseResponse):
auth_parameters = self._get_param("AuthParameters")
auth_result = cognitoidp_backends[self.region].admin_initiate_auth(
user_pool_id,
client_id,
auth_flow,
auth_parameters,
user_pool_id, client_id, auth_flow, auth_parameters
)
return json.dumps(auth_result)
@@ -332,21 +334,15 @@ class CognitoIdpResponse(BaseResponse):
challenge_name = self._get_param("ChallengeName")
challenge_responses = self._get_param("ChallengeResponses")
auth_result = cognitoidp_backends[self.region].respond_to_auth_challenge(
session,
client_id,
challenge_name,
challenge_responses,
session, client_id, challenge_name, challenge_responses
)
return json.dumps(auth_result)
def forgot_password(self):
return json.dumps({
"CodeDeliveryDetails": {
"DeliveryMedium": "EMAIL",
"Destination": "...",
}
})
return json.dumps(
{"CodeDeliveryDetails": {"DeliveryMedium": "EMAIL", "Destination": "..."}}
)
# This endpoint receives no authorization header, so if moto-server is listening
# on localhost (doesn't get a region in the host header), it doesn't know what
@@ -357,7 +353,9 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username")
password = self._get_param("Password")
region = find_region_by_value("client_id", client_id)
cognitoidp_backends[region].confirm_forgot_password(client_id, username, password)
cognitoidp_backends[region].confirm_forgot_password(
client_id, username, password
)
return ""
# Ditto the comment on confirm_forgot_password.
@@ -366,21 +364,26 @@ class CognitoIdpResponse(BaseResponse):
previous_password = self._get_param("PreviousPassword")
proposed_password = self._get_param("ProposedPassword")
region = find_region_by_value("access_token", access_token)
cognitoidp_backends[region].change_password(access_token, previous_password, proposed_password)
cognitoidp_backends[region].change_password(
access_token, previous_password, proposed_password
)
return ""
def admin_update_user_attributes(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
attributes = self._get_param("UserAttributes")
cognitoidp_backends[self.region].admin_update_user_attributes(user_pool_id, username, attributes)
cognitoidp_backends[self.region].admin_update_user_attributes(
user_pool_id, username, attributes
)
return ""
class CognitoIdpJsonWebKeyResponse(BaseResponse):
def __init__(self):
with open(os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")) as f:
with open(
os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")
) as f:
self.json_web_key = f.read()
def serve_json_web_key(self, request, full_url, headers):
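The new Filter handling in list_users does a simple exact match: it strips quotes and splits on "=", so a filter of the form attribute="value" (no spaces around "=") selects users whose attribute matches. An editorial sketch, assuming the mock_cognitoidp decorator; names are illustrative:

```python
import boto3
from moto import mock_cognitoidp


@mock_cognitoidp
def test_list_users_with_filter():
    client = boto3.client("cognito-idp", region_name="us-west-2")
    pool_id = client.create_user_pool(PoolName="test-pool")["UserPool"]["Id"]
    client.admin_create_user(
        UserPoolId=pool_id,
        Username="jdoe",
        UserAttributes=[{"Name": "family_name", "Value": "Doe"}],
    )
    client.admin_create_user(UserPoolId=pool_id, Username="other")

    # Only users carrying the matching attribute value are returned.
    users = client.list_users(UserPoolId=pool_id, Filter='family_name="Doe"')["Users"]
    assert [u["Username"] for u in users] == ["jdoe"]
```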


@@ -1,11 +1,9 @@
from __future__ import unicode_literals
from .responses import CognitoIdpResponse, CognitoIdpJsonWebKeyResponse
url_bases = [
"https?://cognito-idp.(.+).amazonaws.com",
]
url_bases = ["https?://cognito-idp.(.+).amazonaws.com"]
url_paths = {
'{0}/$': CognitoIdpResponse.dispatch,
'{0}/<user_pool_id>/.well-known/jwks.json$': CognitoIdpJsonWebKeyResponse().serve_json_web_key,
"{0}/$": CognitoIdpResponse.dispatch,
"{0}/<user_pool_id>/.well-known/jwks.json$": CognitoIdpJsonWebKeyResponse().serve_json_web_key,
}


@@ -1,5 +1,10 @@
try:
from collections import OrderedDict # flake8: noqa
from collections import OrderedDict # noqa
except ImportError:
# python 2.6 or earlier, use backport
from ordereddict import OrderedDict # flake8: noqa
from ordereddict import OrderedDict # noqa
try:
import collections.abc as collections_abc # noqa
except ImportError:
import collections as collections_abc # noqa


@@ -6,8 +6,12 @@ class NameTooLongException(JsonRESTError):
code = 400
def __init__(self, name, location):
message = '1 validation error detected: Value \'{name}\' at \'{location}\' failed to satisfy' \
' constraint: Member must have length less than or equal to 256'.format(name=name, location=location)
message = (
"1 validation error detected: Value '{name}' at '{location}' failed to satisfy"
" constraint: Member must have length less than or equal to 256".format(
name=name, location=location
)
)
super(NameTooLongException, self).__init__("ValidationException", message)
@@ -15,135 +19,360 @@ class InvalidConfigurationRecorderNameException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'The configuration recorder name \'{name}\' is not valid, blank string.'.format(name=name)
super(InvalidConfigurationRecorderNameException, self).__init__("InvalidConfigurationRecorderNameException",
message)
message = "The configuration recorder name '{name}' is not valid, blank string.".format(
name=name
)
super(InvalidConfigurationRecorderNameException, self).__init__(
"InvalidConfigurationRecorderNameException", message
)
class MaxNumberOfConfigurationRecordersExceededException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Failed to put configuration recorder \'{name}\' because the maximum number of ' \
'configuration recorders: 1 is reached.'.format(name=name)
message = (
"Failed to put configuration recorder '{name}' because the maximum number of "
"configuration recorders: 1 is reached.".format(name=name)
)
super(MaxNumberOfConfigurationRecordersExceededException, self).__init__(
"MaxNumberOfConfigurationRecordersExceededException", message)
"MaxNumberOfConfigurationRecordersExceededException", message
)
class InvalidRecordingGroupException(JsonRESTError):
code = 400
def __init__(self):
message = 'The recording group provided is not valid'
super(InvalidRecordingGroupException, self).__init__("InvalidRecordingGroupException", message)
message = "The recording group provided is not valid"
super(InvalidRecordingGroupException, self).__init__(
"InvalidRecordingGroupException", message
)
class InvalidResourceTypeException(JsonRESTError):
code = 400
def __init__(self, bad_list, good_list):
message = '{num} validation error detected: Value \'{bad_list}\' at ' \
'\'configurationRecorder.recordingGroup.resourceTypes\' failed to satisfy constraint: ' \
'Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]'.format(
num=len(bad_list), bad_list=bad_list, good_list=good_list)
message = (
"{num} validation error detected: Value '{bad_list}' at "
"'configurationRecorder.recordingGroup.resourceTypes' failed to satisfy constraint: "
"Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]".format(
num=len(bad_list), bad_list=bad_list, good_list=good_list
)
)
# For PY2:
message = str(message)
super(InvalidResourceTypeException, self).__init__("ValidationException", message)
super(InvalidResourceTypeException, self).__init__(
"ValidationException", message
)
class NoSuchConfigurationAggregatorException(JsonRESTError):
code = 400
def __init__(self, number=1):
if number == 1:
message = "The configuration aggregator does not exist. Check the configuration aggregator name and try again."
else:
message = (
"At least one of the configuration aggregators does not exist. Check the configuration aggregator"
" names and try again."
)
super(NoSuchConfigurationAggregatorException, self).__init__(
"NoSuchConfigurationAggregatorException", message
)
class NoSuchConfigurationRecorderException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Cannot find configuration recorder with the specified name \'{name}\'.'.format(name=name)
super(NoSuchConfigurationRecorderException, self).__init__("NoSuchConfigurationRecorderException", message)
message = "Cannot find configuration recorder with the specified name '{name}'.".format(
name=name
)
super(NoSuchConfigurationRecorderException, self).__init__(
"NoSuchConfigurationRecorderException", message
)
class InvalidDeliveryChannelNameException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'The delivery channel name \'{name}\' is not valid, blank string.'.format(name=name)
super(InvalidDeliveryChannelNameException, self).__init__("InvalidDeliveryChannelNameException",
message)
message = "The delivery channel name '{name}' is not valid, blank string.".format(
name=name
)
super(InvalidDeliveryChannelNameException, self).__init__(
"InvalidDeliveryChannelNameException", message
)
class NoSuchBucketException(JsonRESTError):
"""We are *only* validating that there is value that is not '' here."""
code = 400
def __init__(self):
message = 'Cannot find a S3 bucket with an empty bucket name.'
message = "Cannot find a S3 bucket with an empty bucket name."
super(NoSuchBucketException, self).__init__("NoSuchBucketException", message)
class InvalidNextTokenException(JsonRESTError):
code = 400
def __init__(self):
message = "The nextToken provided is invalid"
super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", message
)
class InvalidS3KeyPrefixException(JsonRESTError):
code = 400
def __init__(self):
message = 'The s3 key prefix \'\' is not valid, empty s3 key prefix.'
super(InvalidS3KeyPrefixException, self).__init__("InvalidS3KeyPrefixException", message)
message = "The s3 key prefix '' is not valid, empty s3 key prefix."
super(InvalidS3KeyPrefixException, self).__init__(
"InvalidS3KeyPrefixException", message
)
class InvalidSNSTopicARNException(JsonRESTError):
"""We are *only* validating that there is value that is not '' here."""
code = 400
def __init__(self):
message = 'The sns topic arn \'\' is not valid.'
super(InvalidSNSTopicARNException, self).__init__("InvalidSNSTopicARNException", message)
message = "The sns topic arn '' is not valid."
super(InvalidSNSTopicARNException, self).__init__(
"InvalidSNSTopicARNException", message
)
class InvalidDeliveryFrequency(JsonRESTError):
code = 400
def __init__(self, value, good_list):
message = '1 validation error detected: Value \'{value}\' at ' \
'\'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency\' failed to satisfy ' \
'constraint: Member must satisfy enum value set: {good_list}'.format(value=value, good_list=good_list)
super(InvalidDeliveryFrequency, self).__init__("InvalidDeliveryFrequency", message)
message = (
"1 validation error detected: Value '{value}' at "
"'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency' failed to satisfy "
"constraint: Member must satisfy enum value set: {good_list}".format(
value=value, good_list=good_list
)
)
super(InvalidDeliveryFrequency, self).__init__(
"InvalidDeliveryFrequency", message
)
class MaxNumberOfDeliveryChannelsExceededException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Failed to put delivery channel \'{name}\' because the maximum number of ' \
'delivery channels: 1 is reached.'.format(name=name)
message = (
"Failed to put delivery channel '{name}' because the maximum number of "
"delivery channels: 1 is reached.".format(name=name)
)
super(MaxNumberOfDeliveryChannelsExceededException, self).__init__(
"MaxNumberOfDeliveryChannelsExceededException", message)
"MaxNumberOfDeliveryChannelsExceededException", message
)
class NoSuchDeliveryChannelException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Cannot find delivery channel with specified name \'{name}\'.'.format(name=name)
super(NoSuchDeliveryChannelException, self).__init__("NoSuchDeliveryChannelException", message)
message = "Cannot find delivery channel with specified name '{name}'.".format(
name=name
)
super(NoSuchDeliveryChannelException, self).__init__(
"NoSuchDeliveryChannelException", message
)
class NoAvailableConfigurationRecorderException(JsonRESTError):
code = 400
def __init__(self):
message = 'Configuration recorder is not available to put delivery channel.'
super(NoAvailableConfigurationRecorderException, self).__init__("NoAvailableConfigurationRecorderException",
message)
message = "Configuration recorder is not available to put delivery channel."
super(NoAvailableConfigurationRecorderException, self).__init__(
"NoAvailableConfigurationRecorderException", message
)
class NoAvailableDeliveryChannelException(JsonRESTError):
code = 400
def __init__(self):
message = 'Delivery channel is not available to start configuration recorder.'
super(NoAvailableDeliveryChannelException, self).__init__("NoAvailableDeliveryChannelException", message)
message = "Delivery channel is not available to start configuration recorder."
super(NoAvailableDeliveryChannelException, self).__init__(
"NoAvailableDeliveryChannelException", message
)
class LastDeliveryChannelDeleteFailedException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \
'because there is a running configuration recorder.'.format(name=name)
super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message)
message = (
"Failed to delete last specified delivery channel with name '{name}', "
"because there is a running configuration recorder.".format(name=name)
)
super(LastDeliveryChannelDeleteFailedException, self).__init__(
"LastDeliveryChannelDeleteFailedException", message
)
class TooManyAccountSources(JsonRESTError):
code = 400
def __init__(self, length):
locations = ["com.amazonaws.xyz"] * length
message = (
"Value '[{locations}]' at 'accountAggregationSources' failed to satisfy constraint: "
"Member must have length less than or equal to 1".format(
locations=", ".join(locations)
)
)
super(TooManyAccountSources, self).__init__("ValidationException", message)
class DuplicateTags(JsonRESTError):
code = 400
def __init__(self):
super(DuplicateTags, self).__init__(
"InvalidInput",
"Duplicate tag keys found. Please note that Tag keys are case insensitive.",
)
class TagKeyTooBig(JsonRESTError):
code = 400
def __init__(self, tag, param="tags.X.member.key"):
super(TagKeyTooBig, self).__init__(
"ValidationException",
"1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 128".format(
tag, param
),
)
class TagValueTooBig(JsonRESTError):
code = 400
def __init__(self, tag):
super(TagValueTooBig, self).__init__(
"ValidationException",
"1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy "
"constraint: Member must have length less than or equal to 256".format(tag),
)
class InvalidParameterValueException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidParameterValueException, self).__init__(
"InvalidParameterValueException", message
)
class InvalidTagCharacters(JsonRESTError):
code = 400
def __init__(self, tag, param="tags.X.member.key"):
message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(
tag, param
)
message += "constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+"
super(InvalidTagCharacters, self).__init__("ValidationException", message)
class TooManyTags(JsonRESTError):
code = 400
def __init__(self, tags, param="tags"):
super(TooManyTags, self).__init__(
"ValidationException",
"1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 50.".format(
tags, param
),
)
class InvalidResourceParameters(JsonRESTError):
code = 400
def __init__(self):
super(InvalidResourceParameters, self).__init__(
"ValidationException",
"Both Resource ID and Resource Name " "cannot be specified in the request",
)
class InvalidLimit(JsonRESTError):
code = 400
def __init__(self, value):
super(InvalidLimit, self).__init__(
"ValidationException",
"Value '{value}' at 'limit' failed to satisify constraint: Member"
" must have value less than or equal to 100".format(value=value),
)
class TooManyResourceIds(JsonRESTError):
code = 400
def __init__(self):
super(TooManyResourceIds, self).__init__(
"ValidationException",
"The specified list had more than 20 resource ID's. "
"It must have '20' or less items",
)
class ResourceNotDiscoveredException(JsonRESTError):
code = 400
def __init__(self, type, resource):
super(ResourceNotDiscoveredException, self).__init__(
"ResourceNotDiscoveredException",
"Resource {resource} of resourceType:{type} is unknown or has not been "
"discovered".format(resource=resource, type=type),
)
class TooManyResourceKeys(JsonRESTError):
code = 400
def __init__(self, bad_list):
message = (
"1 validation error detected: Value '{bad_list}' at "
"'resourceKeys' failed to satisfy constraint: "
"Member must have length less than or equal to 100".format(
bad_list=bad_list
)
)
# For PY2:
message = str(message)
super(TooManyResourceKeys, self).__init__("ValidationException", message)
class InvalidResultTokenException(JsonRESTError):
code = 400
def __init__(self):
message = "The resultToken provided is invalid"
super(InvalidResultTokenException, self).__init__(
"InvalidResultTokenException", message
)

File diff suppressed because it is too large


@@ -4,50 +4,158 @@ from .models import config_backends
class ConfigResponse(BaseResponse):
@property
def config_backend(self):
return config_backends[self.region]
def put_configuration_recorder(self):
self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder'))
self.config_backend.put_configuration_recorder(
self._get_param("ConfigurationRecorder")
)
return ""
def put_configuration_aggregator(self):
aggregator = self.config_backend.put_configuration_aggregator(
json.loads(self.body), self.region
)
schema = {"ConfigurationAggregator": aggregator}
return json.dumps(schema)
def describe_configuration_aggregators(self):
aggregators = self.config_backend.describe_configuration_aggregators(
self._get_param("ConfigurationAggregatorNames"),
self._get_param("NextToken"),
self._get_param("Limit"),
)
return json.dumps(aggregators)
def delete_configuration_aggregator(self):
self.config_backend.delete_configuration_aggregator(
self._get_param("ConfigurationAggregatorName")
)
return ""
def put_aggregation_authorization(self):
agg_auth = self.config_backend.put_aggregation_authorization(
self.region,
self._get_param("AuthorizedAccountId"),
self._get_param("AuthorizedAwsRegion"),
self._get_param("Tags"),
)
schema = {"AggregationAuthorization": agg_auth}
return json.dumps(schema)
def describe_aggregation_authorizations(self):
authorizations = self.config_backend.describe_aggregation_authorizations(
self._get_param("NextToken"), self._get_param("Limit")
)
return json.dumps(authorizations)
def delete_aggregation_authorization(self):
self.config_backend.delete_aggregation_authorization(
self._get_param("AuthorizedAccountId"),
self._get_param("AuthorizedAwsRegion"),
)
return ""
def describe_configuration_recorders(self):
recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames'))
schema = {'ConfigurationRecorders': recorders}
recorders = self.config_backend.describe_configuration_recorders(
self._get_param("ConfigurationRecorderNames")
)
schema = {"ConfigurationRecorders": recorders}
return json.dumps(schema)
def describe_configuration_recorder_status(self):
recorder_statuses = self.config_backend.describe_configuration_recorder_status(
self._get_param('ConfigurationRecorderNames'))
schema = {'ConfigurationRecordersStatus': recorder_statuses}
self._get_param("ConfigurationRecorderNames")
)
schema = {"ConfigurationRecordersStatus": recorder_statuses}
return json.dumps(schema)
def put_delivery_channel(self):
self.config_backend.put_delivery_channel(self._get_param('DeliveryChannel'))
self.config_backend.put_delivery_channel(self._get_param("DeliveryChannel"))
return ""
def describe_delivery_channels(self):
delivery_channels = self.config_backend.describe_delivery_channels(self._get_param('DeliveryChannelNames'))
schema = {'DeliveryChannels': delivery_channels}
delivery_channels = self.config_backend.describe_delivery_channels(
self._get_param("DeliveryChannelNames")
)
schema = {"DeliveryChannels": delivery_channels}
return json.dumps(schema)
def describe_delivery_channel_status(self):
raise NotImplementedError()
def delete_delivery_channel(self):
self.config_backend.delete_delivery_channel(self._get_param('DeliveryChannelName'))
self.config_backend.delete_delivery_channel(
self._get_param("DeliveryChannelName")
)
return ""
def delete_configuration_recorder(self):
self.config_backend.delete_configuration_recorder(self._get_param('ConfigurationRecorderName'))
self.config_backend.delete_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return ""
def start_configuration_recorder(self):
self.config_backend.start_configuration_recorder(self._get_param('ConfigurationRecorderName'))
self.config_backend.start_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return ""
def stop_configuration_recorder(self):
self.config_backend.stop_configuration_recorder(self._get_param('ConfigurationRecorderName'))
self.config_backend.stop_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return ""
def list_discovered_resources(self):
schema = self.config_backend.list_discovered_resources(
self._get_param("resourceType"),
self.region,
self._get_param("resourceIds"),
self._get_param("resourceName"),
self._get_param("limit"),
self._get_param("nextToken"),
)
return json.dumps(schema)
def list_aggregate_discovered_resources(self):
schema = self.config_backend.list_aggregate_discovered_resources(
self._get_param("ConfigurationAggregatorName"),
self._get_param("ResourceType"),
self._get_param("Filters"),
self._get_param("Limit"),
self._get_param("NextToken"),
)
return json.dumps(schema)
def get_resource_config_history(self):
schema = self.config_backend.get_resource_config_history(
self._get_param("resourceType"), self._get_param("resourceId"), self.region
)
return json.dumps(schema)
def batch_get_resource_config(self):
schema = self.config_backend.batch_get_resource_config(
self._get_param("resourceKeys"), self.region
)
return json.dumps(schema)
def batch_get_aggregate_resource_config(self):
schema = self.config_backend.batch_get_aggregate_resource_config(
self._get_param("ConfigurationAggregatorName"),
self._get_param("ResourceIdentifiers"),
)
return json.dumps(schema)
def put_evaluations(self):
evaluations = self.config_backend.put_evaluations(
self._get_param("Evaluations"),
self._get_param("ResultToken"),
self._get_param("TestMode"),
)
return json.dumps(evaluations)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import ConfigResponse
url_bases = [
"https?://config.(.+).amazonaws.com",
]
url_bases = ["https?://config.(.+).amazonaws.com"]
url_paths = {
'{0}/$': ConfigResponse.dispatch,
}
url_paths = {"{0}/$": ConfigResponse.dispatch}

View File

@ -1,4 +1,9 @@
from __future__ import unicode_literals
from .models import BaseModel, BaseBackend, moto_api_backend # flake8: noqa
from .models import BaseModel, BaseBackend, moto_api_backend, ACCOUNT_ID # noqa
from .responses import ActionAuthenticatorMixin
moto_api_backends = {"global": moto_api_backend}
set_initial_no_auth_action_count = (
ActionAuthenticatorMixin.set_initial_no_auth_action_count
)

404
moto/core/access_control.py Normal file
View File

@ -0,0 +1,404 @@
"""
This implementation is NOT complete, there are many things to improve.
The following is a list of the most important missing features and inaccuracies.
TODO add support for more principals, apart from IAM users and assumed IAM roles
TODO add support for the Resource and Condition parts of IAM policies
TODO add support and create tests for all services in moto (for example, API Gateway is probably not supported currently)
TODO implement service specific error messages (currently, EC2 and S3 are supported separately, everything else defaults to the errors IAM returns)
TODO include information about the action's resource in error messages (once the Resource element in IAM policies is supported)
TODO check all other actions that are performed by the action called by the user (for example, autoscaling:CreateAutoScalingGroup requires permission for iam:CreateServiceLinkedRole too - see https://docs.aws.amazon.com/autoscaling/ec2/userguide/control-access-using-iam.html)
TODO add support for resource-based policies
"""
import json
import logging
import re
from abc import abstractmethod, ABCMeta
from enum import Enum
import six
from botocore.auth import SigV4Auth, S3SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
from six import string_types
from moto.core import ACCOUNT_ID
from moto.iam.models import Policy
from moto.iam import iam_backend
from moto.core.exceptions import (
SignatureDoesNotMatchError,
AccessDeniedError,
InvalidClientTokenIdError,
AuthFailureError,
)
from moto.s3.exceptions import (
BucketAccessDeniedError,
S3AccessDeniedError,
BucketInvalidTokenError,
S3InvalidTokenError,
S3InvalidAccessKeyIdError,
BucketInvalidAccessKeyIdError,
BucketSignatureDoesNotMatchError,
S3SignatureDoesNotMatchError,
)
from moto.sts import sts_backend
log = logging.getLogger(__name__)
def create_access_key(access_key_id, headers):
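    # Long-lived IAM user keys start with "AKIA"; temporary STS credentials
    # instead arrive with an X-Amz-Security-Token header, which is how
    # assumed-role requests are told apart here.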
if access_key_id.startswith("AKIA") or "X-Amz-Security-Token" not in headers:
return IAMUserAccessKey(access_key_id, headers)
else:
return AssumedRoleAccessKey(access_key_id, headers)
class IAMUserAccessKey(object):
def __init__(self, access_key_id, headers):
iam_users = iam_backend.list_users("/", None, None)
for iam_user in iam_users:
for access_key in iam_user.access_keys:
if access_key.access_key_id == access_key_id:
self._owner_user_name = iam_user.name
self._access_key_id = access_key_id
self._secret_access_key = access_key.secret_access_key
if "X-Amz-Security-Token" in headers:
raise CreateAccessKeyFailure(reason="InvalidToken")
return
raise CreateAccessKeyFailure(reason="InvalidId")
@property
def arn(self):
return "arn:aws:iam::{account_id}:user/{iam_user_name}".format(
account_id=ACCOUNT_ID, iam_user_name=self._owner_user_name
)
def create_credentials(self):
return Credentials(self._access_key_id, self._secret_access_key)
def collect_policies(self):
user_policies = []
inline_policy_names = iam_backend.list_user_policies(self._owner_user_name)
for inline_policy_name in inline_policy_names:
inline_policy = iam_backend.get_user_policy(
self._owner_user_name, inline_policy_name
)
user_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_user_policies(
self._owner_user_name
)
user_policies += attached_policies
user_groups = iam_backend.get_groups_for_user(self._owner_user_name)
for user_group in user_groups:
inline_group_policy_names = iam_backend.list_group_policies(user_group.name)
for inline_group_policy_name in inline_group_policy_names:
inline_user_group_policy = iam_backend.get_group_policy(
user_group.name, inline_group_policy_name
)
user_policies.append(inline_user_group_policy)
attached_group_policies, _ = iam_backend.list_attached_group_policies(
user_group.name
)
user_policies += attached_group_policies
return user_policies
class AssumedRoleAccessKey(object):
def __init__(self, access_key_id, headers):
for assumed_role in sts_backend.assumed_roles:
if assumed_role.access_key_id == access_key_id:
self._access_key_id = access_key_id
self._secret_access_key = assumed_role.secret_access_key
self._session_token = assumed_role.session_token
self._owner_role_name = assumed_role.role_arn.split("/")[-1]
self._session_name = assumed_role.session_name
if headers["X-Amz-Security-Token"] != self._session_token:
raise CreateAccessKeyFailure(reason="InvalidToken")
return
raise CreateAccessKeyFailure(reason="InvalidId")
@property
def arn(self):
return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format(
account_id=ACCOUNT_ID,
role_name=self._owner_role_name,
session_name=self._session_name,
)
def create_credentials(self):
return Credentials(
self._access_key_id, self._secret_access_key, self._session_token
)
def collect_policies(self):
role_policies = []
inline_policy_names = iam_backend.list_role_policies(self._owner_role_name)
for inline_policy_name in inline_policy_names:
_, inline_policy = iam_backend.get_role_policy(
self._owner_role_name, inline_policy_name
)
role_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_role_policies(
self._owner_role_name
)
role_policies += attached_policies
return role_policies
class CreateAccessKeyFailure(Exception):
def __init__(self, reason, *args):
super(CreateAccessKeyFailure, self).__init__(*args)
self.reason = reason
@six.add_metaclass(ABCMeta)
class IAMRequestBase(object):
def __init__(self, method, path, data, headers):
log.debug(
"Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format(
class_name=self.__class__.__name__,
method=method,
path=path,
data=data,
headers=headers,
)
)
self._method = method
self._path = path
self._data = data
self._headers = headers
credential_scope = self._get_string_between(
"Credential=", ",", self._headers["Authorization"]
)
credential_data = credential_scope.split("/")
self._region = credential_data[2]
self._service = credential_data[3]
self._action = (
self._service
+ ":"
+ (
self._data["Action"][0]
if isinstance(self._data["Action"], list)
else self._data["Action"]
)
)
try:
self._access_key = create_access_key(
access_key_id=credential_data[0], headers=headers
)
except CreateAccessKeyFailure as e:
self._raise_invalid_access_key(e.reason)
def check_signature(self):
original_signature = self._get_string_between(
"Signature=", ",", self._headers["Authorization"]
)
calculated_signature = self._calculate_signature()
if original_signature != calculated_signature:
self._raise_signature_does_not_match()
def check_action_permitted(self):
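        # IAM evaluation semantics: an explicit Deny in any applicable policy
        # wins immediately; otherwise at least one explicit Allow is required,
        # because the default result is an implicit deny.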
if (
self._action == "sts:GetCallerIdentity"
): # always allowed, even if there's an explicit Deny for it
return True
policies = self._access_key.collect_policies()
permitted = False
for policy in policies:
iam_policy = IAMPolicy(policy)
permission_result = iam_policy.is_action_permitted(self._action)
if permission_result == PermissionResult.DENIED:
self._raise_access_denied()
elif permission_result == PermissionResult.PERMITTED:
permitted = True
if not permitted:
self._raise_access_denied()
@abstractmethod
def _raise_signature_does_not_match(self):
raise NotImplementedError()
@abstractmethod
def _raise_access_denied(self):
raise NotImplementedError()
@abstractmethod
def _raise_invalid_access_key(self, reason):
raise NotImplementedError()
@abstractmethod
def _create_auth(self, credentials):
raise NotImplementedError()
@staticmethod
def _create_headers_for_aws_request(signed_headers, original_headers):
headers = {}
for key, value in original_headers.items():
if key.lower() in signed_headers:
headers[key] = value
return headers
def _create_aws_request(self):
signed_headers = self._get_string_between(
"SignedHeaders=", ",", self._headers["Authorization"]
).split(";")
headers = self._create_headers_for_aws_request(signed_headers, self._headers)
request = AWSRequest(
method=self._method, url=self._path, data=self._data, headers=headers
)
request.context["timestamp"] = headers["X-Amz-Date"]
return request
def _calculate_signature(self):
credentials = self._access_key.create_credentials()
auth = self._create_auth(credentials)
request = self._create_aws_request()
canonical_request = auth.canonical_request(request)
string_to_sign = auth.string_to_sign(request, canonical_request)
return auth.signature(string_to_sign, request)
@staticmethod
def _get_string_between(first_separator, second_separator, string):
return string.partition(first_separator)[2].partition(second_separator)[0]
class IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self):
if self._service == "ec2":
raise AuthFailureError()
else:
raise SignatureDoesNotMatchError()
def _raise_invalid_access_key(self, _):
if self._service == "ec2":
raise AuthFailureError()
else:
raise InvalidClientTokenIdError()
def _create_auth(self, credentials):
return SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self):
raise AccessDeniedError(user_arn=self._access_key.arn, action=self._action)
class S3IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self):
if "BucketName" in self._data:
raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"])
else:
raise S3SignatureDoesNotMatchError()
def _raise_invalid_access_key(self, reason):
if reason == "InvalidToken":
if "BucketName" in self._data:
raise BucketInvalidTokenError(bucket=self._data["BucketName"])
else:
raise S3InvalidTokenError()
else:
if "BucketName" in self._data:
raise BucketInvalidAccessKeyIdError(bucket=self._data["BucketName"])
else:
raise S3InvalidAccessKeyIdError()
def _create_auth(self, credentials):
return S3SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self):
if "BucketName" in self._data:
raise BucketAccessDeniedError(bucket=self._data["BucketName"])
else:
raise S3AccessDeniedError()
class IAMPolicy(object):
def __init__(self, policy):
if isinstance(policy, Policy):
default_version = next(
policy_version
for policy_version in policy.versions
if policy_version.is_default
)
policy_document = default_version.document
elif isinstance(policy, string_types):
policy_document = policy
else:
policy_document = policy["policy_document"]
self._policy_json = json.loads(policy_document)
def is_action_permitted(self, action):
permitted = False
if isinstance(self._policy_json["Statement"], list):
for policy_statement in self._policy_json["Statement"]:
iam_policy_statement = IAMPolicyStatement(policy_statement)
permission_result = iam_policy_statement.is_action_permitted(action)
if permission_result == PermissionResult.DENIED:
return permission_result
elif permission_result == PermissionResult.PERMITTED:
permitted = True
else: # dict
iam_policy_statement = IAMPolicyStatement(self._policy_json["Statement"])
return iam_policy_statement.is_action_permitted(action)
if permitted:
return PermissionResult.PERMITTED
else:
return PermissionResult.NEUTRAL
class IAMPolicyStatement(object):
def __init__(self, statement):
self._statement = statement
def is_action_permitted(self, action):
is_action_concerned = False
if "NotAction" in self._statement:
if not self._check_element_matches("NotAction", action):
is_action_concerned = True
else: # Action is present
if self._check_element_matches("Action", action):
is_action_concerned = True
if is_action_concerned:
if self._statement["Effect"] == "Allow":
return PermissionResult.PERMITTED
else: # Deny
return PermissionResult.DENIED
else:
return PermissionResult.NEUTRAL
def _check_element_matches(self, statement_element, value):
if isinstance(self._statement[statement_element], list):
for statement_element_value in self._statement[statement_element]:
if self._match(statement_element_value, value):
return True
return False
else: # string
return self._match(self._statement[statement_element], value)
@staticmethod
def _match(pattern, string):
pattern = pattern.replace("*", ".*")
pattern = "^{pattern}$".format(pattern=pattern)
return re.match(pattern, string)
class PermissionResult(Enum):
PERMITTED = 1
DENIED = 2
NEUTRAL = 3
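
A minimal sketch of how this evaluation behaves, assuming the classes above are importable from moto.core.access_control as added in this patch; the policy document is illustrative:

from moto.core.access_control import IAMPolicy, PermissionResult

# An illustrative identity policy: allow all of S3, explicitly deny DeleteBucket.
policy_document = """{
    "Version": "2012-10-17",
    "Statement": [
        {"Effect": "Allow", "Action": "s3:*"},
        {"Effect": "Deny", "Action": "s3:DeleteBucket"}
    ]
}"""

policy = IAMPolicy(policy_document)
assert policy.is_action_permitted("s3:GetObject") == PermissionResult.PERMITTED
assert policy.is_action_permitted("s3:DeleteBucket") == PermissionResult.DENIED
# Actions the policy never mentions stay NEUTRAL (an implicit deny overall).
assert policy.is_action_permitted("ec2:RunInstances") == PermissionResult.NEUTRAL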

View File

@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment
SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>{{error_type}}</Code>
<Message>{{message}}</Message>
@ -13,8 +13,8 @@ SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error>
"""
ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
<Response>
ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse>
<Errors>
<Error>
<Code>{{error_type}}</Code>
@ -23,10 +23,10 @@ ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error>
</Errors>
<RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID>
</Response>
</ErrorResponse>
"""
ERROR_JSON_RESPONSE = u"""{
ERROR_JSON_RESPONSE = """{
"message": "{{message}}",
"__type": "{{error_type}}"
}
@ -37,18 +37,19 @@ class RESTError(HTTPException):
code = 400
templates = {
'single_error': SINGLE_ERROR_RESPONSE,
'error': ERROR_RESPONSE,
'error_json': ERROR_JSON_RESPONSE,
"single_error": SINGLE_ERROR_RESPONSE,
"error": ERROR_RESPONSE,
"error_json": ERROR_JSON_RESPONSE,
}
def __init__(self, error_type, message, template='error', **kwargs):
def __init__(self, error_type, message, template="error", **kwargs):
super(RESTError, self).__init__()
env = Environment(loader=DictLoader(self.templates))
self.error_type = error_type
self.message = message
self.description = env.get_template(template).render(
error_type=error_type, message=message, **kwargs)
error_type=error_type, message=message, **kwargs
)
class DryRunClientError(RESTError):
@ -56,12 +57,64 @@ class DryRunClientError(RESTError):
class JsonRESTError(RESTError):
def __init__(self, error_type, message, template='error_json', **kwargs):
super(JsonRESTError, self).__init__(
error_type, message, template, **kwargs)
def __init__(self, error_type, message, template="error_json", **kwargs):
super(JsonRESTError, self).__init__(error_type, message, template, **kwargs)
def get_headers(self, *args, **kwargs):
return [('Content-Type', 'application/json')]
return [("Content-Type", "application/json")]
def get_body(self, *args, **kwargs):
return self.description
class SignatureDoesNotMatchError(RESTError):
code = 403
def __init__(self):
super(SignatureDoesNotMatchError, self).__init__(
"SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.",
)
class InvalidClientTokenIdError(RESTError):
code = 403
def __init__(self):
super(InvalidClientTokenIdError, self).__init__(
"InvalidClientTokenId",
"The security token included in the request is invalid.",
)
class AccessDeniedError(RESTError):
code = 403
def __init__(self, user_arn, action):
super(AccessDeniedError, self).__init__(
"AccessDenied",
"User: {user_arn} is not authorized to perform: {operation}".format(
user_arn=user_arn, operation=action
),
)
class AuthFailureError(RESTError):
code = 401
def __init__(self):
super(AuthFailureError, self).__init__(
"AuthFailure",
"AWS was not able to validate the provided access credentials",
)
class InvalidNextTokenException(JsonRESTError):
"""For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core."""
code = 400
def __init__(self):
super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", "The nextToken provided is invalid"
)

View File

@ -7,6 +7,7 @@ import inspect
import os
import re
import six
import types
from io import BytesIO
from collections import defaultdict
from botocore.handlers import BUILTIN_HANDLERS
@ -23,6 +24,9 @@ from .utils import (
)
ACCOUNT_ID = os.environ.get("MOTO_ACCOUNT_ID", "123456789012")
class BaseMockAWS(object):
nested_count = 0
@ -31,15 +35,20 @@ class BaseMockAWS(object):
self.backends_for_urls = {}
from moto.backends import BACKENDS
default_backends = {
"instance_metadata": BACKENDS['instance_metadata']['global'],
"moto_api": BACKENDS['moto_api']['global'],
"instance_metadata": BACKENDS["instance_metadata"]["global"],
"moto_api": BACKENDS["moto_api"]["global"],
}
self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends)
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"}
FAKE_KEYS = {
"AWS_ACCESS_KEY_ID": "foobar_key",
"AWS_SECRET_ACCESS_KEY": "foobar_secret",
}
self.default_session_mock = mock.patch("boto3.DEFAULT_SESSION", None)
self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
if self.__class__.nested_count == 0:
@ -58,6 +67,7 @@ class BaseMockAWS(object):
self.stop()
def start(self, reset=True):
self.default_session_mock.start()
self.env_variables_mocks.start()
self.__class__.nested_count += 1
@ -68,11 +78,12 @@ class BaseMockAWS(object):
self.enable_patching()
def stop(self):
self.default_session_mock.stop()
self.env_variables_mocks.stop()
self.__class__.nested_count -= 1
if self.__class__.nested_count < 0:
raise RuntimeError('Called stop() before start().')
raise RuntimeError("Called stop() before start().")
if self.__class__.nested_count == 0:
self.disable_patching()
@ -85,6 +96,7 @@ class BaseMockAWS(object):
finally:
self.stop()
return result
functools.update_wrapper(wrapper, func)
wrapper.__wrapped__ = func
return wrapper
@ -122,7 +134,6 @@ class BaseMockAWS(object):
class HttprettyMockAWS(BaseMockAWS):
def reset(self):
HTTPretty.reset()
@ -144,18 +155,26 @@ class HttprettyMockAWS(BaseMockAWS):
HTTPretty.reset()
RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD,
responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT]
RESPONSES_METHODS = [
responses.GET,
responses.DELETE,
responses.HEAD,
responses.OPTIONS,
responses.PATCH,
responses.POST,
responses.PUT,
]
class CallbackResponse(responses.CallbackResponse):
'''
"""
Need to subclass so we can change a couple things
'''
"""
def get_response(self, request):
'''
"""
Need to override this so we can pass decode_content=False
'''
"""
headers = self.get_headers()
result = self.callback(request)
@ -177,17 +196,17 @@ class CallbackResponse(responses.CallbackResponse):
)
def _url_matches(self, url, other, match_querystring=False):
'''
"""
Need to override this so we can fix querystrings breaking regex matching
'''
"""
if not match_querystring:
other = other.split('?', 1)[0]
other = other.split("?", 1)[0]
if responses._is_string(url):
if responses._has_unicode(url):
url = responses._clean_unicode(url)
if not isinstance(other, six.text_type):
other = other.encode('ascii').decode('utf8')
other = other.encode("ascii").decode("utf8")
return self._url_matches_strict(url, other)
elif isinstance(url, responses.Pattern) and url.match(other):
return True
@ -195,66 +214,40 @@ class CallbackResponse(responses.CallbackResponse):
return False
botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send')
botocore_mock = responses.RequestsMock(
assert_all_requests_are_fired=False,
target="botocore.vendored.requests.adapters.HTTPAdapter.send",
)
responses_mock = responses._default_mock
# Add passthrough to allow any other requests to work
# Since this uses .startswith, it applies to http and https requests.
responses_mock.add_passthru("http")
class ResponsesMockAWS(BaseMockAWS):
def reset(self):
botocore_mock.reset()
responses_mock.reset()
def _find_first_match(self, request):
for i, match in enumerate(self._matches):
if match.matches(request):
return match
def enable_patching(self):
if not hasattr(botocore_mock, '_patcher') or not hasattr(botocore_mock._patcher, 'target'):
# Check for unactivated patcher
botocore_mock.start()
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'):
responses_mock.start()
for method in RESPONSES_METHODS:
for backend in self.backends_for_urls.values():
for key, value in backend.urls.items():
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile(key),
callback=convert_flask_to_responses_response(value),
stream=True,
match_querystring=False,
)
)
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile(key),
callback=convert_flask_to_responses_response(value),
stream=True,
match_querystring=False,
)
)
def disable_patching(self):
try:
botocore_mock.stop()
except RuntimeError:
pass
try:
responses_mock.stop()
except RuntimeError:
pass
return None
BOTOCORE_HTTP_METHODS = [
'GET', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT'
]
# Modify behaviour of the matcher to only/always return the first match
# Default behaviour is to return subsequent matches for subsequent requests, which leads to https://github.com/spulec/moto/issues/2567
# - First request matches on the appropriate S3 URL
# - Same request, executed again, will be matched on the subsequent match, which happens to be the catch-all, not-yet-implemented, callback
# Fix: Always return the first match
responses_mock._find_match = types.MethodType(_find_first_match, responses_mock)
BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
class MockRawResponse(BytesIO):
def __init__(self, input):
if isinstance(input, six.text_type):
input = input.encode('utf-8')
input = input.encode("utf-8")
super(MockRawResponse, self).__init__(input)
def stream(self, **kwargs):
@ -285,7 +278,7 @@ class BotocoreStubber(object):
found_index = None
matchers = self.methods.get(request.method)
base_url = request.url.split('?', 1)[0]
base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers):
if pattern.match(base_url):
if found_index is None:
@ -298,8 +291,10 @@ class BotocoreStubber(object):
if response_callback is not None:
for header, value in request.headers.items():
if isinstance(value, six.binary_type):
request.headers[header] = value.decode('utf-8')
status, headers, body = response_callback(request, request.url, request.headers)
request.headers[header] = value.decode("utf-8")
status, headers, body = response_callback(
request, request.url, request.headers
)
body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body)
@ -307,7 +302,15 @@ class BotocoreStubber(object):
botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(('before-send', botocore_stubber))
BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
def not_implemented_callback(request):
status = 400
headers = {}
response = "The method is not implemented"
return status, headers, response
class BotocoreEventMockAWS(BaseMockAWS):
@ -323,7 +326,9 @@ class BotocoreEventMockAWS(BaseMockAWS):
pattern = re.compile(key)
botocore_stubber.register_response(method, pattern, value)
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'):
if not hasattr(responses_mock, "_patcher") or not hasattr(
responses_mock._patcher, "target"
):
responses_mock.start()
for method in RESPONSES_METHODS:
@ -339,6 +344,24 @@ class BotocoreEventMockAWS(BaseMockAWS):
match_querystring=False,
)
)
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile(r"https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
)
)
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile(r"https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
)
)
def disable_patching(self):
botocore_stubber.enabled = False
@ -354,9 +377,9 @@ MockAWS = BotocoreEventMockAWS
class ServerModeMockAWS(BaseMockAWS):
def reset(self):
import requests
requests.post("http://localhost:5000/moto-api/reset")
def enable_patching(self):
@ -368,13 +391,13 @@ class ServerModeMockAWS(BaseMockAWS):
import mock
def fake_boto3_client(*args, **kwargs):
if 'endpoint_url' not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000"
if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_client(*args, **kwargs)
def fake_boto3_resource(*args, **kwargs):
if 'endpoint_url' not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000"
if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_resource(*args, **kwargs)
def fake_httplib_send_output(self, message_body=None, *args, **kwargs):
@ -382,7 +405,7 @@ class ServerModeMockAWS(BaseMockAWS):
bytes_buffer = []
for chunk in mixed_buffer:
if isinstance(chunk, six.text_type):
bytes_buffer.append(chunk.encode('utf-8'))
bytes_buffer.append(chunk.encode("utf-8"))
else:
bytes_buffer.append(chunk)
msg = b"\r\n".join(bytes_buffer)
@ -403,10 +426,12 @@ class ServerModeMockAWS(BaseMockAWS):
if message_body is not None:
self.send(message_body)
self._client_patcher = mock.patch('boto3.client', fake_boto3_client)
self._resource_patcher = mock.patch('boto3.resource', fake_boto3_resource)
self._client_patcher = mock.patch("boto3.client", fake_boto3_client)
self._resource_patcher = mock.patch("boto3.resource", fake_boto3_resource)
if six.PY2:
self._httplib_patcher = mock.patch('httplib.HTTPConnection._send_output', fake_httplib_send_output)
self._httplib_patcher = mock.patch(
"httplib.HTTPConnection._send_output", fake_httplib_send_output
)
self._client_patcher.start()
self._resource_patcher.start()
@ -422,7 +447,6 @@ class ServerModeMockAWS(BaseMockAWS):
class Model(type):
def __new__(self, clsname, bases, namespace):
cls = super(Model, self).__new__(self, clsname, bases, namespace)
cls.__models__ = {}
@ -437,9 +461,11 @@ class Model(type):
@staticmethod
def prop(model_name):
""" decorator to mark a class method as returning model values """
def dec(f):
f.__returns_model__ = model_name
return f
return dec
@ -449,7 +475,7 @@ model_data = defaultdict(dict)
class InstanceTrackerMeta(type):
def __new__(meta, name, bases, dct):
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
if name == 'BaseModel':
if name == "BaseModel":
return cls
service = cls.__module__.split(".")[1]
@ -468,7 +494,6 @@ class BaseModel(object):
class BaseBackend(object):
def _reset_model_refs(self):
# Remove all references to the models stored
for service, models in model_data.items():
@ -484,8 +509,9 @@ class BaseBackend(object):
def _url_module(self):
backend_module = self.__class__.__module__
backend_urls_module_name = backend_module.replace("models", "urls")
backend_urls_module = __import__(backend_urls_module_name, fromlist=[
'url_bases', 'url_paths'])
backend_urls_module = __import__(
backend_urls_module_name, fromlist=["url_bases", "url_paths"]
)
return backend_urls_module
@property
@ -541,9 +567,9 @@ class BaseBackend(object):
def decorator(self, func=None):
if settings.TEST_SERVER_MODE:
mocked_backend = ServerModeMockAWS({'global': self})
mocked_backend = ServerModeMockAWS({"global": self})
else:
mocked_backend = MockAWS({'global': self})
mocked_backend = MockAWS({"global": self})
if func:
return mocked_backend(func)
@ -552,9 +578,101 @@ class BaseBackend(object):
def deprecated_decorator(self, func=None):
if func:
return HttprettyMockAWS({'global': self})(func)
return HttprettyMockAWS({"global": self})(func)
else:
return HttprettyMockAWS({'global': self})
return HttprettyMockAWS({"global": self})
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
# raise NotImplementedError()
class ConfigQueryModel(object):
def __init__(self, backends):
"""Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends
def list_config_service_resources(
self,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
):
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region.
This supports both aggregated and non-aggregated listing. The following notes the difference:
- Non-Aggregated Listing -
This only lists resources within a region. The way that this is implemented in moto is based on the region
for the resource backend.
You must set the `backend_region` to the region that the API request arrived from. resource_region can be set to `None`.
- Aggregated Listing -
This lists resources from all potential regional backends. For non-global resource types, this should collect a full
list of resources from all the backends, and then be able to filter from the resource region. This is because an
aggregator can aggregate resources from multiple regions. In moto, aggregated regions will *assume full aggregation
from all resources in all regions for a given resource type*.
The `backend_region` should be set to `None` for these queries, and the `resource_region` should optionally be set to
the `Filters` region parameter to filter out resources that reside in a specific region.
For aggregated listings, pagination logic should be set such that the next page can properly span all the region backends.
As such, the proper way to implement is to first obtain a full list of results from all the region backends, and then filter
from there. It may be valuable to make this a concatenation of the region and resource name.
:param resource_ids: A list of resource IDs
:param resource_name: The individual name of a resource
:param limit: How many per page
:param next_token: The item that will page on
:param backend_region: The region for the backend to pull results from. Set to `None` if this is an aggregated query.
:param resource_region: The region for where the resources reside to pull results from. Set to `None` if this is a
non-aggregated query.
:return: This should return a list of Dicts that have the following fields:
[
{
'type': 'AWS::The AWS Config data type',
'name': 'The name of the resource',
'id': 'The ID of the resource',
'region': 'The region of the resource -- if global, then you may want to have the calling logic pass the
aggregator region in for the resource region -- or just us-east-1 :P'
}
, ...
]
"""
raise NotImplementedError()
def get_config_resource(
self, resource_id, resource_name=None, backend_region=None, resource_region=None
):
"""For AWS Config. This will query the backend for the specific resource type configuration.
This supports both aggregated and non-aggregated fetching -- for batched fetching, the Config batching requests
will call this function N times to fetch the N objects needing to be fetched.
- Non-Aggregated Fetching -
This only fetches a resource config within a region. The way that this is implemented in moto is based on the region
for the resource backend.
You must set the `backend_region` to the region that the API request arrived from. `resource_region` should be set to `None`.
- Aggregated Fetching -
This fetches resources from all potential regional backends. For non-global resource types, this should collect a full
list of resources from all the backends, and then be able to filter from the resource region. This is because an
aggregator can aggregate resources from multiple regions. In moto, aggregated regions will *assume full aggregation
from all resources in all regions for a given resource type*.
...
:param resource_id:
:param resource_name:
:param backend_region:
:param resource_region:
:return:
"""
raise NotImplementedError()
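
As an illustration of these two contracts -- and not moto's actual S3 query implementation -- a hypothetical resource type might wire itself up roughly like the sketch below. The per-region `things` dict on the backend is invented for the sketch, and real configuration items follow the AWS Config schema more closely:

class HypotheticalThingConfigQuery(ConfigQueryModel):
    def list_config_service_resources(self, resource_ids, resource_name, limit,
                                      next_token, backend_region=None,
                                      resource_region=None):
        # Non-aggregated: only the caller's region. Aggregated: every backend,
        # optionally narrowed afterwards by the Filters region.
        regions = [backend_region] if backend_region else sorted(self.backends)
        found = []
        for region in regions:
            if resource_region and region != resource_region:
                continue
            for thing in self.backends[region].things.values():  # invented store
                if resource_ids and thing.id not in resource_ids:
                    continue
                if resource_name and thing.name != resource_name:
                    continue
                found.append({"type": "AWS::Hypothetical::Thing", "id": thing.id,
                              "name": thing.name, "region": region})
        # Sort the full cross-region list first so a numeric offset token
        # pages consistently across backends (limit is assumed to be an int).
        found.sort(key=lambda item: (item["region"], item["id"]))
        start = int(next_token) if next_token else 0
        new_token = str(start + limit) if start + limit < len(found) else None
        return found[start:start + limit], new_token

    def get_config_resource(self, resource_id, resource_name=None,
                            backend_region=None, resource_region=None):
        # Batched callers invoke this once per resource key; returning None
        # lets them skip resources that no longer exist.
        region = backend_region or resource_region
        backend = self.backends.get(region)
        thing = backend.things.get(resource_id) if backend else None
        if thing is None or (resource_name and thing.name != resource_name):
            return None
        return {"version": "1.3", "resourceType": "AWS::Hypothetical::Thing",
                "resourceId": thing.id, "awsRegion": region,
                "configuration": {"name": thing.name}}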
class base_decorator(object):
@ -580,9 +698,9 @@ class deprecated_base_decorator(base_decorator):
class MotoAPIBackend(BaseBackend):
def reset(self):
from moto.backends import BACKENDS
for name, backends in BACKENDS.items():
if name == "moto_api":
continue

View File

@ -1,13 +1,17 @@
from __future__ import unicode_literals
import functools
from collections import defaultdict
import datetime
import json
import logging
import re
import io
import requests
import pytz
from moto.core.access_control import IAMRequest, S3IAMRequest
from moto.core.exceptions import DryRunClientError
from jinja2 import Environment, DictLoader, TemplateNotFound
@ -22,7 +26,7 @@ from werkzeug.exceptions import HTTPException
import boto3
from moto.compat import OrderedDict
from moto.core.utils import camelcase_to_underscores, method_names_from_class
from moto import settings
log = logging.getLogger(__name__)
@ -36,7 +40,7 @@ def _decode_dict(d):
newkey = []
for k in key:
if isinstance(k, six.binary_type):
newkey.append(k.decode('utf-8'))
newkey.append(k.decode("utf-8"))
else:
newkey.append(k)
else:
@ -48,7 +52,7 @@ def _decode_dict(d):
newvalue = []
for v in value:
if isinstance(v, six.binary_type):
newvalue.append(v.decode('utf-8'))
newvalue.append(v.decode("utf-8"))
else:
newvalue.append(v)
else:
@ -79,12 +83,15 @@ class DynamicDictLoader(DictLoader):
class _TemplateEnvironmentMixin(object):
LEFT_PATTERN = re.compile(r"[\s\n]+<")
RIGHT_PATTERN = re.compile(r">[\s\n]+")
def __init__(self):
super(_TemplateEnvironmentMixin, self).__init__()
self.loader = DynamicDictLoader({})
self.environment = Environment(
loader=self.loader, autoescape=self.should_autoescape)
loader=self.loader, autoescape=self.should_autoescape
)
@property
def should_autoescape(self):
@ -97,19 +104,92 @@ class _TemplateEnvironmentMixin(object):
def response_template(self, source):
template_id = id(source)
if not self.contains_template(template_id):
self.loader.update({template_id: source})
self.environment = Environment(loader=self.loader, autoescape=self.should_autoescape, trim_blocks=True,
lstrip_blocks=True)
collapsed = re.sub(
self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source)
)
self.loader.update({template_id: collapsed})
self.environment = Environment(
loader=self.loader,
autoescape=self.should_autoescape,
trim_blocks=True,
lstrip_blocks=True,
)
return self.environment.get_template(template_id)
class BaseResponse(_TemplateEnvironmentMixin):
class ActionAuthenticatorMixin(object):
default_region = 'us-east-1'
request_count = 0
def _authenticate_and_authorize_action(self, iam_request_cls):
if (
ActionAuthenticatorMixin.request_count
>= settings.INITIAL_NO_AUTH_ACTION_COUNT
):
iam_request = iam_request_cls(
method=self.method, path=self.path, data=self.data, headers=self.headers
)
iam_request.check_signature()
iam_request.check_action_permitted()
else:
ActionAuthenticatorMixin.request_count += 1
def _authenticate_and_authorize_normal_action(self):
self._authenticate_and_authorize_action(IAMRequest)
def _authenticate_and_authorize_s3_action(self):
self._authenticate_and_authorize_action(S3IAMRequest)
@staticmethod
def set_initial_no_auth_action_count(initial_no_auth_action_count):
def decorator(function):
def wrapper(*args, **kwargs):
if settings.TEST_SERVER_MODE:
response = requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(initial_no_auth_action_count).encode(),
)
original_initial_no_auth_action_count = response.json()[
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT"
]
else:
original_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
original_request_count = ActionAuthenticatorMixin.request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count
ActionAuthenticatorMixin.request_count = 0
try:
result = function(*args, **kwargs)
finally:
if settings.TEST_SERVER_MODE:
requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(original_initial_no_auth_action_count).encode(),
)
else:
ActionAuthenticatorMixin.request_count = original_request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = (
original_initial_no_auth_action_count
)
return result
functools.update_wrapper(wrapper, function)
wrapper.__wrapped__ = function
return wrapper
return decorator
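
Downstream, tests apply this as a decorator; a sketch of the intended usage, assuming the re-export from moto.core added earlier in this commit (the count of 4 is illustrative -- it is how many unauthenticated setup calls are allowed before enforcement starts):

from moto import mock_iam
from moto.core import set_initial_no_auth_action_count

@set_initial_no_auth_action_count(4)
@mock_iam
def test_denied_without_permissions():
    # The first four actions (e.g. creating an IAM user and its access key)
    # bypass the auth checks; every request after that must carry a valid
    # SigV4 signature and be allowed by the caller's IAM policies.
    ...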
class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = "us-east-1"
# to extract region, use [^.]
region_regex = re.compile(r'\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com')
param_list_regex = re.compile(r'(.*)\.(\d+)\.')
access_key_regex = re.compile(r'AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]')
region_regex = re.compile(r"\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com")
param_list_regex = re.compile(r"(.*)\.(\d+)\.")
access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
)
aws_service_spec = None
@classmethod
@ -118,7 +198,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
def setup_class(self, request, full_url, headers):
querystring = {}
if hasattr(request, 'body'):
if hasattr(request, "body"):
# Boto
self.body = request.body
else:
@ -131,24 +211,29 @@ class BaseResponse(_TemplateEnvironmentMixin):
querystring = {}
for key, value in request.form.items():
querystring[key] = [value, ]
querystring[key] = [value]
raw_body = self.body
if isinstance(self.body, six.binary_type):
self.body = self.body.decode('utf-8')
self.body = self.body.decode("utf-8")
if not querystring:
querystring.update(
parse_qs(urlparse(full_url).query, keep_blank_values=True))
parse_qs(urlparse(full_url).query, keep_blank_values=True)
)
if not querystring:
if 'json' in request.headers.get('content-type', []) and self.aws_service_spec:
if (
"json" in request.headers.get("content-type", [])
and self.aws_service_spec
):
decoded = json.loads(self.body)
target = request.headers.get(
'x-amz-target') or request.headers.get('X-Amz-Target')
service, method = target.split('.')
target = request.headers.get("x-amz-target") or request.headers.get(
"X-Amz-Target"
)
service, method = target.split(".")
input_spec = self.aws_service_spec.input_spec(method)
flat = flatten_json_request_body('', decoded, input_spec)
flat = flatten_json_request_body("", decoded, input_spec)
for key, value in flat.items():
querystring[key] = [value]
elif self.body:
@ -167,22 +252,25 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.uri = full_url
self.path = urlparse(full_url).path
self.querystring = querystring
self.data = querystring
self.method = request.method
self.region = self.get_region_from_url(request, full_url)
self.uri_match = None
self.headers = request.headers
if 'host' not in self.headers:
self.headers['host'] = urlparse(full_url).netloc
if "host" not in self.headers:
self.headers["host"] = urlparse(full_url).netloc
self.response_headers = {"server": "amazon.com"}
def get_region_from_url(self, request, full_url):
match = self.region_regex.search(full_url)
if match:
region = match.group(1)
elif 'Authorization' in request.headers and 'AWS4' in request.headers['Authorization']:
region = request.headers['Authorization'].split(",")[
0].split("/")[2]
elif (
"Authorization" in request.headers
and "AWS4" in request.headers["Authorization"]
):
region = request.headers["Authorization"].split(",")[0].split("/")[2]
else:
region = self.default_region
return region
@ -191,16 +279,16 @@ class BaseResponse(_TemplateEnvironmentMixin):
"""
Returns the access key id used in this request as the current user id
"""
if 'Authorization' in self.headers:
match = self.access_key_regex.search(self.headers['Authorization'])
if "Authorization" in self.headers:
match = self.access_key_regex.search(self.headers["Authorization"])
if match:
return match.group(1)
if self.querystring.get('AWSAccessKeyId'):
return self.querystring.get('AWSAccessKeyId')
if self.querystring.get("AWSAccessKeyId"):
return self.querystring.get("AWSAccessKeyId")
else:
# Should we raise an unauthorized exception instead?
return '111122223333'
return "111122223333"
def _dispatch(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -215,17 +303,22 @@ class BaseResponse(_TemplateEnvironmentMixin):
-> '^/cars/.*/drivers/.*/drive$'
"""
def _convert(elem, is_last):
if not re.match('^{.*}$', elem):
return elem
name = elem.replace('{', '').replace('}', '')
if is_last:
return '(?P<%s>[^/]*)' % name
return '(?P<%s>.*)' % name
elems = uri.split('/')
def _convert(elem, is_last):
if not re.match("^{.*}$", elem):
return elem
name = elem.replace("{", "").replace("}", "").replace("+", "")
if is_last:
return "(?P<%s>[^/]*)" % name
return "(?P<%s>.*)" % name
elems = uri.split("/")
num_elems = len(elems)
regexp = '^{}$'.format('/'.join([_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)]))
regexp = "^{}$".format(
"/".join(
[_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)]
)
)
return regexp
def _get_action_from_method_and_request_uri(self, method, request_uri):
@ -236,19 +329,19 @@ class BaseResponse(_TemplateEnvironmentMixin):
# service response class should have 'SERVICE_NAME' class member,
# if you want to get action from method and url
if not hasattr(self, 'SERVICE_NAME'):
if not hasattr(self, "SERVICE_NAME"):
return None
service = self.SERVICE_NAME
conn = boto3.client(service, region_name=self.region)
# make cache if it does not exist yet
if not hasattr(self, 'method_urls'):
if not hasattr(self, "method_urls"):
self.method_urls = defaultdict(lambda: defaultdict(str))
op_names = conn._service_model.operation_names
for op_name in op_names:
op_model = conn._service_model.operation_model(op_name)
_method = op_model.http['method']
uri_regexp = self.uri_to_regexp(op_model.http['requestUri'])
_method = op_model.http["method"]
uri_regexp = self.uri_to_regexp(op_model.http["requestUri"])
self.method_urls[_method][uri_regexp] = op_model.name
regexp_and_names = self.method_urls[method]
for regexp, name in regexp_and_names.items():
@ -259,11 +352,10 @@ class BaseResponse(_TemplateEnvironmentMixin):
return None
def _get_action(self):
action = self.querystring.get('Action', [""])[0]
action = self.querystring.get("Action", [""])[0]
if not action: # Some services use a header for the action
# Headers are case-insensitive. Probably a better way to do this.
match = self.headers.get(
'x-amz-target') or self.headers.get('X-Amz-Target')
match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target")
if match:
action = match.split(".")[-1]
# get action from method and uri
@ -273,6 +365,13 @@ class BaseResponse(_TemplateEnvironmentMixin):
def call_action(self):
headers = self.response_headers
try:
self._authenticate_and_authorize_normal_action()
except HTTPException as http_error:
response = http_error.description, dict(status=http_error.code)
return self._send_response(headers, response)
action = camelcase_to_underscores(self._get_action())
method_names = method_names_from_class(self.__class__)
if action in method_names:
@ -285,22 +384,27 @@ class BaseResponse(_TemplateEnvironmentMixin):
if isinstance(response, six.string_types):
return 200, headers, response
else:
if len(response) == 2:
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
return status, headers, body
return self._send_response(headers, response)
if not action:
return 404, headers, ''
return 404, headers, ""
raise NotImplementedError(
"The {0} action has not been implemented".format(action))
"The {0} action has not been implemented".format(action)
)
@staticmethod
def _send_response(headers, response):
if len(response) == 2:
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get("status", 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers["status"] = str(headers["status"])
return status, headers, body
def _get_param(self, param_name, if_none=None):
val = self.querystring.get(param_name)
@ -333,9 +437,9 @@ class BaseResponse(_TemplateEnvironmentMixin):
def _get_bool_param(self, param_name, if_none=None):
val = self._get_param(param_name)
if val is not None:
if val.lower() == 'true':
if val.lower() == "true":
return True
elif val.lower() == 'false':
elif val.lower() == "false":
return False
return if_none
@ -353,11 +457,16 @@ class BaseResponse(_TemplateEnvironmentMixin):
if is_tracked(name) or not name.startswith(param_prefix):
continue
if len(name) > len(param_prefix) and \
not name[len(param_prefix):].startswith('.'):
if len(name) > len(param_prefix) and not name[
len(param_prefix) :
].startswith("."):
continue
match = self.param_list_regex.search(name[len(param_prefix):]) if len(name) > len(param_prefix) else None
match = (
self.param_list_regex.search(name[len(param_prefix) :])
if len(name) > len(param_prefix)
else None
)
if match:
prefix = param_prefix + match.group(1)
value = self._get_multi_param(prefix)
@ -372,7 +481,10 @@ class BaseResponse(_TemplateEnvironmentMixin):
if len(value_dict) > 1:
# strip off period prefix
value_dict = {name[len(param_prefix) + 1:]: value for name, value in value_dict.items()}
value_dict = {
name[len(param_prefix) + 1 :]: value
for name, value in value_dict.items()
}
else:
value_dict = list(value_dict.values())[0]
@ -391,7 +503,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
index = 1
while True:
value_dict = self._get_multi_param_helper(prefix + str(index))
if not value_dict:
if not value_dict and value_dict != "":
break
values.append(value_dict)
@ -416,8 +528,9 @@ class BaseResponse(_TemplateEnvironmentMixin):
params = {}
for key, value in self.querystring.items():
if key.startswith(param_prefix):
params[camelcase_to_underscores(
key.replace(param_prefix, ""))] = value[0]
params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[
0
]
return params
def _get_list_prefix(self, param_prefix):
@ -450,19 +563,20 @@ class BaseResponse(_TemplateEnvironmentMixin):
new_items = {}
for key, value in self.querystring.items():
if key.startswith(index_prefix):
new_items[camelcase_to_underscores(
key.replace(index_prefix, ""))] = value[0]
new_items[
camelcase_to_underscores(key.replace(index_prefix, ""))
] = value[0]
if not new_items:
break
results.append(new_items)
param_index += 1
return results
def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'):
def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"):
results = {}
param_index = 1
while 1:
index_prefix = '{0}.{1}.'.format(param_prefix, param_index)
index_prefix = "{0}.{1}.".format(param_prefix, param_index)
k, v = None, None
for key, value in self.querystring.items():
@ -489,8 +603,8 @@ class BaseResponse(_TemplateEnvironmentMixin):
param_index = 1
while True:
key_name = 'tag.{0}._key'.format(param_index)
value_name = 'tag.{0}._value'.format(param_index)
key_name = "tag.{0}._key".format(param_index)
value_name = "tag.{0}._value".format(param_index)
try:
results[resource_type][tag[key_name]] = tag[value_name]
@ -500,7 +614,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
return results
def _get_object_map(self, prefix, name='Name', value='Value'):
def _get_object_map(self, prefix, name="Name", value="Value"):
"""
Given a query dict like
{
@ -528,15 +642,14 @@ class BaseResponse(_TemplateEnvironmentMixin):
index = 1
while True:
# Loop through looking for keys representing object name
name_key = '{0}.{1}.{2}'.format(prefix, index, name)
name_key = "{0}.{1}.{2}".format(prefix, index, name)
obj_name = self.querystring.get(name_key)
if not obj_name:
# Found all keys
break
obj = {}
value_key_prefix = '{0}.{1}.{2}.'.format(
prefix, index, value)
value_key_prefix = "{0}.{1}.{2}.".format(prefix, index, value)
for k, v in self.querystring.items():
if k.startswith(value_key_prefix):
_, value_key = k.split(value_key_prefix, 1)
@ -550,25 +663,48 @@ class BaseResponse(_TemplateEnvironmentMixin):
@property
def request_json(self):
return 'JSON' in self.querystring.get('ContentType', [])
return "JSON" in self.querystring.get("ContentType", [])
def is_not_dryrun(self, action):
if 'true' in self.querystring.get('DryRun', ['false']):
message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action
raise DryRunClientError(
error_type="DryRunOperation", message=message)
if "true" in self.querystring.get("DryRun", ["false"]):
message = (
"An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set"
% action
)
raise DryRunClientError(error_type="DryRunOperation", message=message)
return True
class MotoAPIResponse(BaseResponse):
def reset_response(self, request, full_url, headers):
if request.method == "POST":
from .models import moto_api_backend
moto_api_backend.reset()
return 200, {}, json.dumps({"status": "ok"})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"})
def reset_auth_response(self, request, full_url, headers):
if request.method == "POST":
previous_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode())
ActionAuthenticatorMixin.request_count = 0
return (
200,
{},
json.dumps(
{
"status": "ok",
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(
previous_initial_no_auth_action_count
),
}
),
)
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
def model_data(self, request, full_url, headers):
from moto.core.models import model_data
@ -594,7 +730,8 @@ class MotoAPIResponse(BaseResponse):
def dashboard(self, request, full_url, headers):
from flask import render_template
return render_template('dashboard.html')
return render_template("dashboard.html")
class _RecursiveDictRef(object):
@ -605,7 +742,7 @@ class _RecursiveDictRef(object):
self.dic = {}
def __repr__(self):
return '{!r}'.format(self.dic)
return "{!r}".format(self.dic)
def __getattr__(self, key):
return self.dic.__getattr__(key)
@ -629,21 +766,21 @@ class AWSServiceSpec(object):
"""
def __init__(self, path):
self.path = resource_filename('botocore', path)
with io.open(self.path, 'r', encoding='utf-8') as f:
self.path = resource_filename("botocore", path)
with io.open(self.path, "r", encoding="utf-8") as f:
spec = json.load(f)
self.metadata = spec['metadata']
self.operations = spec['operations']
self.shapes = spec['shapes']
self.metadata = spec["metadata"]
self.operations = spec["operations"]
self.shapes = spec["shapes"]
def input_spec(self, operation):
try:
op = self.operations[operation]
except KeyError:
raise ValueError('Invalid operation: {}'.format(operation))
if 'input' not in op:
raise ValueError("Invalid operation: {}".format(operation))
if "input" not in op:
return {}
shape = self.shapes[op['input']['shape']]
shape = self.shapes[op["input"]["shape"]]
return self._expand(shape)
def output_spec(self, operation):
@ -657,129 +794,133 @@ class AWSServiceSpec(object):
try:
op = self.operations[operation]
except KeyError:
raise ValueError('Invalid operation: {}'.format(operation))
if 'output' not in op:
raise ValueError("Invalid operation: {}".format(operation))
if "output" not in op:
return {}
shape = self.shapes[op['output']['shape']]
shape = self.shapes[op["output"]["shape"]]
return self._expand(shape)
def _expand(self, shape):
def expand(dic, seen=None):
seen = seen or {}
if dic['type'] == 'structure':
if dic["type"] == "structure":
nodes = {}
for k, v in dic['members'].items():
for k, v in dic["members"].items():
seen_till_here = dict(seen)
if k in seen_till_here:
nodes[k] = seen_till_here[k]
continue
seen_till_here[k] = _RecursiveDictRef()
nodes[k] = expand(self.shapes[v['shape']], seen_till_here)
nodes[k] = expand(self.shapes[v["shape"]], seen_till_here)
seen_till_here[k].set_reference(k, nodes[k])
nodes['type'] = 'structure'
nodes["type"] = "structure"
return nodes
elif dic['type'] == 'list':
elif dic["type"] == "list":
seen_till_here = dict(seen)
shape = dic['member']['shape']
shape = dic["member"]["shape"]
if shape in seen_till_here:
return seen_till_here[shape]
seen_till_here[shape] = _RecursiveDictRef()
expanded = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, expanded)
return {'type': 'list', 'member': expanded}
return {"type": "list", "member": expanded}
elif dic['type'] == 'map':
elif dic["type"] == "map":
seen_till_here = dict(seen)
node = {'type': 'map'}
node = {"type": "map"}
if 'shape' in dic['key']:
shape = dic['key']['shape']
if "shape" in dic["key"]:
shape = dic["key"]["shape"]
seen_till_here[shape] = _RecursiveDictRef()
node['key'] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['key'])
node["key"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node["key"])
else:
node['key'] = dic['key']['type']
node["key"] = dic["key"]["type"]
if 'shape' in dic['value']:
shape = dic['value']['shape']
if "shape" in dic["value"]:
shape = dic["value"]["shape"]
seen_till_here[shape] = _RecursiveDictRef()
node['value'] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['value'])
node["value"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node["value"])
else:
node['value'] = dic['value']['type']
node["value"] = dic["value"]["type"]
return node
else:
return {'type': dic['type']}
return {"type": dic["type"]}
return expand(shape)
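
For example -- assuming the installed botocore still ships its SQS service definition at this relative path, which is how these specs are packaged -- loading and expanding an operation looks like:

spec = AWSServiceSpec(path="data/sqs/2012-11-05/service-2.json")
create_queue_input = spec.input_spec("CreateQueue")
# The recursive shape references are now plain nested dicts, e.g.
# create_queue_input["QueueName"] should be {"type": "string"}.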
def to_str(value, spec):
vtype = spec['type']
if vtype == 'boolean':
return 'true' if value else 'false'
elif vtype == 'integer':
vtype = spec["type"]
if vtype == "boolean":
return "true" if value else "false"
elif vtype == "integer":
return str(value)
elif vtype == 'float':
elif vtype == "float":
return str(value)
elif vtype == 'double':
elif vtype == "double":
return str(value)
elif vtype == 'timestamp':
return datetime.datetime.utcfromtimestamp(
value).replace(tzinfo=pytz.utc).isoformat()
elif vtype == 'string':
elif vtype == "timestamp":
return (
datetime.datetime.utcfromtimestamp(value)
.replace(tzinfo=pytz.utc)
.isoformat()
)
elif vtype == "string":
return str(value)
elif value is None:
return 'null'
return "null"
else:
raise TypeError('Unknown type {}'.format(vtype))
raise TypeError("Unknown type {}".format(vtype))
def from_str(value, spec):
vtype = spec['type']
if vtype == 'boolean':
return True if value == 'true' else False
elif vtype == 'integer':
vtype = spec["type"]
if vtype == "boolean":
return True if value == "true" else False
elif vtype == "integer":
return int(value)
elif vtype == 'float':
elif vtype == "float":
return float(value)
elif vtype == 'double':
elif vtype == "double":
return float(value)
elif vtype == 'timestamp':
elif vtype == "timestamp":
return value
elif vtype == 'string':
elif vtype == "string":
return value
raise TypeError('Unknown type {}'.format(vtype))
raise TypeError("Unknown type {}".format(vtype))
def flatten_json_request_body(prefix, dict_body, spec):
"""Convert a JSON request body into query params."""
if len(spec) == 1 and 'type' in spec:
if len(spec) == 1 and "type" in spec:
return {prefix: to_str(dict_body, spec)}
flat = {}
for key, value in dict_body.items():
node_type = spec[key]['type']
if node_type == 'list':
node_type = spec[key]["type"]
if node_type == "list":
for idx, v in enumerate(value, 1):
pref = key + '.member.' + str(idx)
flat.update(flatten_json_request_body(
pref, v, spec[key]['member']))
elif node_type == 'map':
pref = key + ".member." + str(idx)
flat.update(flatten_json_request_body(pref, v, spec[key]["member"]))
elif node_type == "map":
for idx, (k, v) in enumerate(value.items(), 1):
pref = key + '.entry.' + str(idx)
flat.update(flatten_json_request_body(
pref + '.key', k, spec[key]['key']))
flat.update(flatten_json_request_body(
pref + '.value', v, spec[key]['value']))
pref = key + ".entry." + str(idx)
flat.update(
flatten_json_request_body(pref + ".key", k, spec[key]["key"])
)
flat.update(
flatten_json_request_body(pref + ".value", v, spec[key]["value"])
)
else:
flat.update(flatten_json_request_body(key, value, spec[key]))
if prefix:
prefix = prefix + '.'
prefix = prefix + "."
return dict((prefix + k, v) for k, v in flat.items())
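A minimal sketch of the flattened layout, assuming a hypothetical spec with a single list-of-strings member (map entries would flatten to ".entry.N.key" / ".entry.N.value" pairs in the same way):

spec = {"Tags": {"type": "list", "member": {"type": "string"}}}
body = {"Tags": ["alpha", "beta"]}
flatten_json_request_body("", body, spec)
# -> {"Tags.member.1": "alpha", "Tags.member.2": "beta"}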
@ -802,41 +943,40 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
od = OrderedDict()
for k, v in value.items():
if k.startswith('@'):
if k.startswith("@"):
continue
if k not in spec:
# this can happen with an older version of
# botocore for which the node in the XML template is not
# defined in the service spec.
log.warning(
'Field %s is not defined by the botocore version in use', k)
log.warning("Field %s is not defined by the botocore version in use", k)
continue
if spec[k]['type'] == 'list':
if spec[k]["type"] == "list":
if v is None:
od[k] = []
elif len(spec[k]['member']) == 1:
if isinstance(v['member'], list):
od[k] = transform(v['member'], spec[k]['member'])
elif len(spec[k]["member"]) == 1:
if isinstance(v["member"], list):
od[k] = transform(v["member"], spec[k]["member"])
else:
od[k] = [transform(v['member'], spec[k]['member'])]
elif isinstance(v['member'], list):
od[k] = [transform(o, spec[k]['member'])
for o in v['member']]
elif isinstance(v['member'], OrderedDict):
od[k] = [transform(v['member'], spec[k]['member'])]
od[k] = [transform(v["member"], spec[k]["member"])]
elif isinstance(v["member"], list):
od[k] = [transform(o, spec[k]["member"]) for o in v["member"]]
elif isinstance(v["member"], OrderedDict):
od[k] = [transform(v["member"], spec[k]["member"])]
else:
raise ValueError('Malformatted input')
elif spec[k]['type'] == 'map':
raise ValueError("Malformatted input")
elif spec[k]["type"] == "map":
if v is None:
od[k] = {}
else:
items = ([v['entry']] if not isinstance(v['entry'], list) else
v['entry'])
items = (
[v["entry"]] if not isinstance(v["entry"], list) else v["entry"]
)
for item in items:
key = from_str(item['key'], spec[k]['key'])
val = from_str(item['value'], spec[k]['value'])
key = from_str(item["key"], spec[k]["key"])
val = from_str(item["value"], spec[k]["value"])
if k not in od:
od[k] = {}
od[k][key] = val
@ -850,7 +990,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
dic = xmltodict.parse(xml)
output_spec = service_spec.output_spec(operation)
try:
for k in (result_node or (operation + 'Response', operation + 'Result')):
for k in result_node or (operation + "Response", operation + "Result"):
dic = dic[k]
except KeyError:
return None


@ -1,14 +1,13 @@
from __future__ import unicode_literals
from .responses import MotoAPIResponse
url_bases = [
"https?://motoapi.amazonaws.com"
]
url_bases = ["https?://motoapi.amazonaws.com"]
response_instance = MotoAPIResponse()
url_paths = {
'{0}/moto-api/$': response_instance.dashboard,
'{0}/moto-api/data.json': response_instance.model_data,
'{0}/moto-api/reset': response_instance.reset_response,
"{0}/moto-api/$": response_instance.dashboard,
"{0}/moto-api/data.json": response_instance.model_data,
"{0}/moto-api/reset": response_instance.reset_response,
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
}


@ -8,6 +8,7 @@ import random
import re
import six
import string
from botocore.exceptions import ClientError
from six.moves.urllib.parse import urlparse
@ -15,9 +16,9 @@ REQUEST_ID_LONG = string.digits + string.ascii_uppercase
def camelcase_to_underscores(argument):
''' Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute'''
result = ''
""" Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute"""
result = ""
prev_char_title = True
if not argument:
return argument
@ -41,18 +42,18 @@ def camelcase_to_underscores(argument):
def underscores_to_camelcase(argument):
''' Converts a camelcase param like the_new_attribute to the equivalent
""" Converts a camelcase param like the_new_attribute to the equivalent
camelcase version like theNewAttribute. Note that the first letter is
NOT capitalized by this function '''
result = ''
NOT capitalized by this function """
result = ""
previous_was_underscore = False
for char in argument:
if char != '_':
if char != "_":
if previous_was_underscore:
result += char.upper()
else:
result += char
previous_was_underscore = char == '_'
previous_was_underscore = char == "_"
return result
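Per the docstrings, the two helpers are near-inverses of each other; a quick sketch:

camelcase_to_underscores("theNewAttribute")    # -> "the_new_attribute"
underscores_to_camelcase("the_new_attribute")  # -> "theNewAttribute"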
@ -69,12 +70,18 @@ def method_names_from_class(clazz):
def get_random_hex(length=8):
chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f']
return ''.join(six.text_type(random.choice(chars)) for x in range(length))
chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"]
return "".join(six.text_type(random.choice(chars)) for x in range(length))
def get_random_message_id():
return '{0}-{1}-{2}-{3}-{4}'.format(get_random_hex(8), get_random_hex(4), get_random_hex(4), get_random_hex(4), get_random_hex(12))
return "{0}-{1}-{2}-{3}-{4}".format(
get_random_hex(8),
get_random_hex(4),
get_random_hex(4),
get_random_hex(4),
get_random_hex(12),
)
def convert_regex_to_flask_path(url_path):
@ -88,7 +95,7 @@ def convert_regex_to_flask_path(url_path):
match_name, match_pattern = reg.groups()
return '<regex("{0}"):{1}>'.format(match_pattern, match_name)
url_path = re.sub("\(\?P<(.*?)>(.*?)\)", caller, url_path)
url_path = re.sub(r"\(\?P<(.*?)>(.*?)\)", caller, url_path)
if url_path.endswith("/?"):
# Flask does own handling of trailing slashes
@ -97,7 +104,6 @@ def convert_regex_to_flask_path(url_path):
class convert_httpretty_response(object):
def __init__(self, callback):
self.callback = callback
@ -114,13 +120,12 @@ class convert_httpretty_response(object):
def __call__(self, request, url, headers, **kwargs):
result = self.callback(request, url, headers)
status, headers, response = result
if 'server' not in headers:
if "server" not in headers:
headers["server"] = "amazon.com"
return status, headers, response
class convert_flask_to_httpretty_response(object):
def __init__(self, callback):
self.callback = callback
@ -137,7 +142,10 @@ class convert_flask_to_httpretty_response(object):
def __call__(self, args=None, **kwargs):
from flask import request, Response
result = self.callback(request, request.url, {})
try:
result = self.callback(request, request.url, {})
except ClientError as exc:
result = 400, {}, exc.response["Error"]["Message"]
# result is a status, headers, response tuple
if len(result) == 3:
status, headers, content = result
@ -145,13 +153,12 @@ class convert_flask_to_httpretty_response(object):
status, headers, content = 200, {}, result
response = Response(response=content, status=status, headers=headers)
if request.method == "HEAD" and 'content-length' in headers:
response.headers['Content-Length'] = headers['content-length']
if request.method == "HEAD" and "content-length" in headers:
response.headers["Content-Length"] = headers["content-length"]
return response
class convert_flask_to_responses_response(object):
def __init__(self, callback):
self.callback = callback
@ -176,14 +183,14 @@ class convert_flask_to_responses_response(object):
def iso_8601_datetime_with_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + 'Z'
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def iso_8601_datetime_without_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z'
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
RFC1123 = '%a, %d %b %Y %H:%M:%S GMT'
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(datetime):
@ -212,16 +219,16 @@ def gen_amz_crc32(response, headerdict=None):
crc = str(binascii.crc32(response))
if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amz-crc32': crc})
headerdict.update({"x-amz-crc32": crc})
return crc
def gen_amzn_requestid_long(headerdict=None):
req_id = ''.join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
req_id = "".join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amzn-requestid': req_id})
headerdict.update({"x-amzn-requestid": req_id})
return req_id
@ -239,13 +246,13 @@ def amz_crc32(f):
else:
if len(response) == 2:
body, new_headers = response
status = new_headers.get('status', 200)
status = new_headers.get("status", 200)
else:
status, new_headers, body = response
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
headers["status"] = str(headers["status"])
try:
# Doesn't work on python2 for some odd unicode strings
@ -271,7 +278,7 @@ def amzn_request_id(f):
else:
if len(response) == 2:
body, new_headers = response
status = new_headers.get('status', 200)
status = new_headers.get("status", 200)
else:
status, new_headers, body = response
headers.update(new_headers)
@ -280,7 +287,7 @@ def amzn_request_id(f):
# Update request ID in XML
try:
body = re.sub(r'(?<=<RequestId>).*(?=<\/RequestId>)', request_id, body)
body = re.sub(r"(?<=<RequestId>).*(?=<\/RequestId>)", request_id, body)
except Exception: # Will just ignore if it can't work on bytes (which are strs on python2)
pass
@ -293,7 +300,31 @@ def path_url(url):
parsed_url = urlparse(url)
path = parsed_url.path
if not path:
path = '/'
path = "/"
if parsed_url.query:
path = path + '?' + parsed_url.query
path = path + "?" + parsed_url.query
return path
def py2_strip_unicode_keys(blob):
"""For Python 2 Only -- this will convert unicode keys in nested Dicts, Lists, and Sets to standard strings."""
if type(blob) == unicode: # noqa
return str(blob)
elif type(blob) == dict:
for key in list(blob.keys()):
value = blob.pop(key)
blob[str(key)] = py2_strip_unicode_keys(value)
elif type(blob) == list:
for i in range(0, len(blob)):
blob[i] = py2_strip_unicode_keys(blob[i])
elif type(blob) == set:
new_set = set()
for value in blob:
new_set.add(py2_strip_unicode_keys(value))
blob = new_set
return blob
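A hedged example of the Python 2 behaviour (illustrative only; on Python 3 the `unicode` builtin does not exist, which is why the function is Python 2 only):

# On Python 2 (illustrative):
# py2_strip_unicode_keys({u"name": [u"a", u"b"]})
# -> {"name": ["a", "b"]}   (unicode keys and leaf values become str)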


@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import datapipeline_backends
from ..core.models import base_decorator, deprecated_base_decorator
datapipeline_backend = datapipeline_backends['us-east-1']
datapipeline_backend = datapipeline_backends["us-east-1"]
mock_datapipeline = base_decorator(datapipeline_backends)
mock_datapipeline_deprecated = deprecated_base_decorator(datapipeline_backends)


@ -1,92 +1,73 @@
from __future__ import unicode_literals
import datetime
import boto.datapipeline
from boto3 import Session
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys
class PipelineObject(BaseModel):
def __init__(self, object_id, name, fields):
self.object_id = object_id
self.name = name
self.fields = fields
def to_json(self):
return {
"fields": self.fields,
"id": self.object_id,
"name": self.name,
}
return {"fields": self.fields, "id": self.object_id, "name": self.name}
class Pipeline(BaseModel):
def __init__(self, name, unique_id, **kwargs):
self.name = name
self.unique_id = unique_id
self.description = kwargs.get('description', '')
self.description = kwargs.get("description", "")
self.pipeline_id = get_random_pipeline_id()
self.creation_time = datetime.datetime.utcnow()
self.objects = []
self.status = "PENDING"
self.tags = kwargs.get('tags', [])
self.tags = kwargs.get("tags", [])
@property
def physical_resource_id(self):
return self.pipeline_id
def to_meta_json(self):
return {
"id": self.pipeline_id,
"name": self.name,
}
return {"id": self.pipeline_id, "name": self.name}
def to_json(self):
return {
"description": self.description,
"fields": [{
"key": "@pipelineState",
"stringValue": self.status,
}, {
"key": "description",
"stringValue": self.description
}, {
"key": "name",
"stringValue": self.name
}, {
"key": "@creationTime",
"stringValue": datetime.datetime.strftime(self.creation_time, '%Y-%m-%dT%H-%M-%S'),
}, {
"key": "@id",
"stringValue": self.pipeline_id,
}, {
"key": "@sphere",
"stringValue": "PIPELINE"
}, {
"key": "@version",
"stringValue": "1"
}, {
"key": "@userId",
"stringValue": "924374875933"
}, {
"key": "@accountId",
"stringValue": "924374875933"
}, {
"key": "uniqueId",
"stringValue": self.unique_id
}],
"fields": [
{"key": "@pipelineState", "stringValue": self.status},
{"key": "description", "stringValue": self.description},
{"key": "name", "stringValue": self.name},
{
"key": "@creationTime",
"stringValue": datetime.datetime.strftime(
self.creation_time, "%Y-%m-%dT%H-%M-%S"
),
},
{"key": "@id", "stringValue": self.pipeline_id},
{"key": "@sphere", "stringValue": "PIPELINE"},
{"key": "@version", "stringValue": "1"},
{"key": "@userId", "stringValue": "924374875933"},
{"key": "@accountId", "stringValue": "924374875933"},
{"key": "uniqueId", "stringValue": self.unique_id},
],
"name": self.name,
"pipelineId": self.pipeline_id,
"tags": self.tags
"tags": self.tags,
}
def set_pipeline_objects(self, pipeline_objects):
self.objects = [
PipelineObject(pipeline_object['id'], pipeline_object[
'name'], pipeline_object['fields'])
PipelineObject(
pipeline_object["id"],
pipeline_object["name"],
pipeline_object["fields"],
)
for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects)
]
@ -94,15 +75,19 @@ class Pipeline(BaseModel):
self.status = "SCHEDULED"
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
datapipeline_backend = datapipeline_backends[region_name]
properties = cloudformation_json["Properties"]
cloudformation_unique_id = "cf-" + properties["Name"]
pipeline = datapipeline_backend.create_pipeline(
properties["Name"], cloudformation_unique_id)
properties["Name"], cloudformation_unique_id
)
datapipeline_backend.put_pipeline_definition(
pipeline.pipeline_id, properties["PipelineObjects"])
pipeline.pipeline_id, properties["PipelineObjects"]
)
if properties["Activate"]:
pipeline.activate()
@ -110,7 +95,6 @@ class Pipeline(BaseModel):
class DataPipelineBackend(BaseBackend):
def __init__(self):
self.pipelines = OrderedDict()
@ -123,8 +107,11 @@ class DataPipelineBackend(BaseBackend):
return self.pipelines.values()
def describe_pipelines(self, pipeline_ids):
pipelines = [pipeline for pipeline in self.pipelines.values(
) if pipeline.pipeline_id in pipeline_ids]
pipelines = [
pipeline
for pipeline in self.pipelines.values()
if pipeline.pipeline_id in pipeline_ids
]
return pipelines
def get_pipeline(self, pipeline_id):
@ -144,7 +131,8 @@ class DataPipelineBackend(BaseBackend):
def describe_objects(self, object_ids, pipeline_id):
pipeline = self.get_pipeline(pipeline_id)
pipeline_objects = [
pipeline_object for pipeline_object in pipeline.objects
pipeline_object
for pipeline_object in pipeline.objects
if pipeline_object.object_id in object_ids
]
return pipeline_objects
@ -155,5 +143,11 @@ class DataPipelineBackend(BaseBackend):
datapipeline_backends = {}
for region in boto.datapipeline.regions():
datapipeline_backends[region.name] = DataPipelineBackend()
for region in Session().get_available_regions("datapipeline"):
datapipeline_backends[region] = DataPipelineBackend()
for region in Session().get_available_regions(
"datapipeline", partition_name="aws-us-gov"
):
datapipeline_backends[region] = DataPipelineBackend()
for region in Session().get_available_regions("datapipeline", partition_name="aws-cn"):
datapipeline_backends[region] = DataPipelineBackend()


@ -7,7 +7,6 @@ from .models import datapipeline_backends
class DataPipelineResponse(BaseResponse):
@property
def parameters(self):
# TODO this should really be moved to core/responses.py
@ -21,47 +20,47 @@ class DataPipelineResponse(BaseResponse):
return datapipeline_backends[self.region]
def create_pipeline(self):
name = self.parameters.get('name')
unique_id = self.parameters.get('uniqueId')
description = self.parameters.get('description', '')
tags = self.parameters.get('tags', [])
pipeline = self.datapipeline_backend.create_pipeline(name, unique_id, description=description, tags=tags)
return json.dumps({
"pipelineId": pipeline.pipeline_id,
})
name = self.parameters.get("name")
unique_id = self.parameters.get("uniqueId")
description = self.parameters.get("description", "")
tags = self.parameters.get("tags", [])
pipeline = self.datapipeline_backend.create_pipeline(
name, unique_id, description=description, tags=tags
)
return json.dumps({"pipelineId": pipeline.pipeline_id})
def list_pipelines(self):
pipelines = list(self.datapipeline_backend.list_pipelines())
pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines]
max_pipelines = 50
marker = self.parameters.get('marker')
marker = self.parameters.get("marker")
if marker:
start = pipeline_ids.index(marker) + 1
else:
start = 0
pipelines_resp = pipelines[start:start + max_pipelines]
pipelines_resp = pipelines[start : start + max_pipelines]
has_more_results = False
marker = None
if start + max_pipelines < len(pipeline_ids) - 1:
has_more_results = True
marker = pipelines_resp[-1].pipeline_id
return json.dumps({
"hasMoreResults": has_more_results,
"marker": marker,
"pipelineIdList": [
pipeline.to_meta_json() for pipeline in pipelines_resp
]
})
return json.dumps(
{
"hasMoreResults": has_more_results,
"marker": marker,
"pipelineIdList": [
pipeline.to_meta_json() for pipeline in pipelines_resp
],
}
)
def describe_pipelines(self):
pipeline_ids = self.parameters["pipelineIds"]
pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids)
return json.dumps({
"pipelineDescriptionList": [
pipeline.to_json() for pipeline in pipelines
]
})
return json.dumps(
{"pipelineDescriptionList": [pipeline.to_json() for pipeline in pipelines]}
)
def delete_pipeline(self):
pipeline_id = self.parameters["pipelineId"]
@ -72,31 +71,38 @@ class DataPipelineResponse(BaseResponse):
pipeline_id = self.parameters["pipelineId"]
pipeline_objects = self.parameters["pipelineObjects"]
self.datapipeline_backend.put_pipeline_definition(
pipeline_id, pipeline_objects)
self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects)
return json.dumps({"errored": False})
def get_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"]
pipeline_definition = self.datapipeline_backend.get_pipeline_definition(
pipeline_id)
return json.dumps({
"pipelineObjects": [pipeline_object.to_json() for pipeline_object in pipeline_definition]
})
pipeline_id
)
return json.dumps(
{
"pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_definition
]
}
)
def describe_objects(self):
pipeline_id = self.parameters["pipelineId"]
object_ids = self.parameters["objectIds"]
pipeline_objects = self.datapipeline_backend.describe_objects(
object_ids, pipeline_id)
return json.dumps({
"hasMoreResults": False,
"marker": None,
"pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_objects
]
})
object_ids, pipeline_id
)
return json.dumps(
{
"hasMoreResults": False,
"marker": None,
"pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_objects
],
}
)
def activate_pipeline(self):
pipeline_id = self.parameters["pipelineId"]


@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import DataPipelineResponse
url_bases = [
"https?://datapipeline.(.+).amazonaws.com",
]
url_bases = ["https?://datapipeline.(.+).amazonaws.com"]
url_paths = {
'{0}/$': DataPipelineResponse.dispatch,
}
url_paths = {"{0}/$": DataPipelineResponse.dispatch}


@ -1,5 +1,5 @@
import collections
import six
from moto.compat import collections_abc
from moto.core.utils import get_random_hex
@ -8,13 +8,15 @@ def get_random_pipeline_id():
def remove_capitalization_of_dict_keys(obj):
if isinstance(obj, collections.Mapping):
if isinstance(obj, collections_abc.Mapping):
result = obj.__class__()
for key, value in obj.items():
normalized_key = key[:1].lower() + key[1:]
result[normalized_key] = remove_capitalization_of_dict_keys(value)
return result
elif isinstance(obj, collections.Iterable) and not isinstance(obj, six.string_types):
elif isinstance(obj, collections_abc.Iterable) and not isinstance(
obj, six.string_types
):
result = obj.__class__()
for item in obj:
result += (remove_capitalization_of_dict_keys(item),)


@ -0,0 +1,8 @@
from __future__ import unicode_literals
from ..core.models import base_decorator, deprecated_base_decorator
from .models import datasync_backends
datasync_backend = datasync_backends["us-east-1"]
mock_datasync = base_decorator(datasync_backends)
mock_datasync_deprecated = deprecated_base_decorator(datasync_backends)


@ -0,0 +1,15 @@
from __future__ import unicode_literals
from moto.core.exceptions import JsonRESTError
class DataSyncClientError(JsonRESTError):
code = 400
class InvalidRequestException(DataSyncClientError):
def __init__(self, msg=None):
self.code = 400
super(InvalidRequestException, self).__init__(
"InvalidRequestException", msg or "The request is not valid."
)

moto/datasync/models.py (new file)

@ -0,0 +1,235 @@
from boto3 import Session
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from .exceptions import InvalidRequestException
class Location(BaseModel):
def __init__(
self, location_uri, region_name=None, typ=None, metadata=None, arn_counter=0
):
self.uri = location_uri
self.region_name = region_name
self.metadata = metadata
self.typ = typ
# Generate ARN
self.arn = "arn:aws:datasync:{0}:111222333444:location/loc-{1}".format(
region_name, str(arn_counter).zfill(17)
)
class Task(BaseModel):
def __init__(
self,
source_location_arn,
destination_location_arn,
name,
region_name,
arn_counter=0,
metadata=None,
):
self.source_location_arn = source_location_arn
self.destination_location_arn = destination_location_arn
self.name = name
self.metadata = metadata
# For simplicity Tasks are either available or running
self.status = "AVAILABLE"
self.current_task_execution_arn = None
# Generate ARN
self.arn = "arn:aws:datasync:{0}:111222333444:task/task-{1}".format(
region_name, str(arn_counter).zfill(17)
)
class TaskExecution(BaseModel):
# For simplicity, task_execution can never fail
# Some documentation refers to this list:
# 'Status': 'QUEUED'|'LAUNCHING'|'PREPARING'|'TRANSFERRING'|'VERIFYING'|'SUCCESS'|'ERROR'
# Others refer to this list:
# INITIALIZING | PREPARING | TRANSFERRING | VERIFYING | SUCCESS/FAILURE
# Checking with AWS Support...
TASK_EXECUTION_INTERMEDIATE_STATES = (
"INITIALIZING",
# 'QUEUED', 'LAUNCHING',
"PREPARING",
"TRANSFERRING",
"VERIFYING",
)
TASK_EXECUTION_FAILURE_STATES = ("ERROR",)
TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
# Also COMPLETED state?
def __init__(self, task_arn, arn_counter=0):
self.task_arn = task_arn
self.arn = "{0}/execution/exec-{1}".format(task_arn, str(arn_counter).zfill(17))
self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[0]
# Simulate a task execution
def iterate_status(self):
if self.status in self.TASK_EXECUTION_FAILURE_STATES:
return
if self.status in self.TASK_EXECUTION_SUCCESS_STATES:
return
if self.status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
for i, status in enumerate(self.TASK_EXECUTION_INTERMEDIATE_STATES):
if status == self.status:
if i < len(self.TASK_EXECUTION_INTERMEDIATE_STATES) - 1:
self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[i + 1]
else:
self.status = self.TASK_EXECUTION_SUCCESS_STATES[0]
return
raise Exception(
"TaskExecution.iterate_status: Unknown status={0}".format(self.status)
)
def cancel(self):
if self.status not in self.TASK_EXECUTION_INTERMEDIATE_STATES:
raise InvalidRequestException(
"Sync task cannot be cancelled in its current status: {0}".format(
self.status
)
)
self.status = "ERROR"
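A minimal sketch of the simulated state machine (the ARN below is made up): each iterate_status() call advances one intermediate state before landing on SUCCESS.

te = TaskExecution("arn:aws:datasync:us-east-1:111222333444:task/task-00000000000000000")
while te.status not in TaskExecution.TASK_EXECUTION_SUCCESS_STATES:
    te.iterate_status()  # INITIALIZING -> PREPARING -> TRANSFERRING -> VERIFYING -> SUCCESS
assert te.status == "SUCCESS"
# te.cancel() would now raise InvalidRequestException, since the execution
# is no longer in an intermediate state.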
class DataSyncBackend(BaseBackend):
def __init__(self, region_name):
self.region_name = region_name
# Always increase when new things are created
# This ensures uniqueness
self.arn_counter = 0
self.locations = OrderedDict()
self.tasks = OrderedDict()
self.task_executions = OrderedDict()
def reset(self):
region_name = self.region_name
self._reset_model_refs()
self.__dict__ = {}
self.__init__(region_name)
def create_location(self, location_uri, typ=None, metadata=None):
"""
# AWS DataSync allows for duplicate LocationUris
for arn, location in self.locations.items():
if location.uri == location_uri:
raise Exception('Location already exists')
"""
if not typ:
raise Exception("Location type must be specified")
self.arn_counter = self.arn_counter + 1
location = Location(
location_uri,
region_name=self.region_name,
arn_counter=self.arn_counter,
metadata=metadata,
typ=typ,
)
self.locations[location.arn] = location
return location.arn
def _get_location(self, location_arn, typ):
if location_arn not in self.locations:
raise InvalidRequestException(
"Location {0} is not found.".format(location_arn)
)
location = self.locations[location_arn]
if location.typ != typ:
raise InvalidRequestException(
"Invalid Location type: {0}".format(location.typ)
)
return location
def delete_location(self, location_arn):
if location_arn in self.locations:
del self.locations[location_arn]
else:
raise InvalidRequestException
def create_task(
self, source_location_arn, destination_location_arn, name, metadata=None
):
if source_location_arn not in self.locations:
raise InvalidRequestException(
"Location {0} not found.".format(source_location_arn)
)
if destination_location_arn not in self.locations:
raise InvalidRequestException(
"Location {0} not found.".format(destination_location_arn)
)
self.arn_counter = self.arn_counter + 1
task = Task(
source_location_arn,
destination_location_arn,
name,
region_name=self.region_name,
arn_counter=self.arn_counter,
metadata=metadata,
)
self.tasks[task.arn] = task
return task.arn
def _get_task(self, task_arn):
if task_arn in self.tasks:
return self.tasks[task_arn]
else:
raise InvalidRequestException
def update_task(self, task_arn, name, metadata):
if task_arn in self.tasks:
task = self.tasks[task_arn]
task.name = name
task.metadata = metadata
else:
raise InvalidRequestException(
"Sync task {0} is not found.".format(task_arn)
)
def delete_task(self, task_arn):
if task_arn in self.tasks:
del self.tasks[task_arn]
else:
raise InvalidRequestException
def start_task_execution(self, task_arn):
self.arn_counter = self.arn_counter + 1
if task_arn in self.tasks:
task = self.tasks[task_arn]
if task.status == "AVAILABLE":
task_execution = TaskExecution(task_arn, arn_counter=self.arn_counter)
self.task_executions[task_execution.arn] = task_execution
self.tasks[task_arn].current_task_execution_arn = task_execution.arn
self.tasks[task_arn].status = "RUNNING"
return task_execution.arn
raise InvalidRequestException("Invalid request.")
def _get_task_execution(self, task_execution_arn):
if task_execution_arn in self.task_executions:
return self.task_executions[task_execution_arn]
else:
raise InvalidRequestException
def cancel_task_execution(self, task_execution_arn):
if task_execution_arn in self.task_executions:
task_execution = self.task_executions[task_execution_arn]
task_execution.cancel()
task_arn = task_execution.task_arn
self.tasks[task_arn].current_task_execution_arn = None
self.tasks[task_arn].status = "AVAILABLE"
return
raise InvalidRequestException(
"Sync task {0} is not found.".format(task_execution_arn)
)
datasync_backends = {}
for region in Session().get_available_regions("datasync"):
datasync_backends[region] = DataSyncBackend(region)
for region in Session().get_available_regions("datasync", partition_name="aws-us-gov"):
datasync_backends[region] = DataSyncBackend(region)
for region in Session().get_available_regions("datasync", partition_name="aws-cn"):
datasync_backends[region] = DataSyncBackend(region)
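A hedged end-to-end sketch of the backend above (bucket, share, and task names are invented):

backend = DataSyncBackend("us-east-1")
src = backend.create_location("smb://fileserver/share", typ="SMB")
dst = backend.create_location("s3://example-bucket/prefix", typ="S3")
task_arn = backend.create_task(src, dst, "nightly-sync")
exec_arn = backend.start_task_execution(task_arn)
assert backend.tasks[task_arn].status == "RUNNING"
backend.cancel_task_execution(exec_arn)  # flips the task back to AVAILABLE
assert backend.tasks[task_arn].status == "AVAILABLE"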

moto/datasync/responses.py (new file)

@ -0,0 +1,162 @@
import json
from moto.core.responses import BaseResponse
from .models import datasync_backends
class DataSyncResponse(BaseResponse):
@property
def datasync_backend(self):
return datasync_backends[self.region]
def list_locations(self):
locations = list()
for arn, location in self.datasync_backend.locations.items():
locations.append({"LocationArn": location.arn, "LocationUri": location.uri})
return json.dumps({"Locations": locations})
def _get_location(self, location_arn, typ):
return self.datasync_backend._get_location(location_arn, typ)
def create_location_s3(self):
# s3://bucket_name/folder/
s3_bucket_arn = self._get_param("S3BucketArn")
subdirectory = self._get_param("Subdirectory")
metadata = {"S3Config": self._get_param("S3Config")}
location_uri_elts = ["s3:/", s3_bucket_arn.split(":")[-1]]
if subdirectory:
location_uri_elts.append(subdirectory)
location_uri = "/".join(location_uri_elts)
arn = self.datasync_backend.create_location(
location_uri, metadata=metadata, typ="S3"
)
return json.dumps({"LocationArn": arn})
def describe_location_s3(self):
location_arn = self._get_param("LocationArn")
location = self._get_location(location_arn, typ="S3")
return json.dumps(
{
"LocationArn": location.arn,
"LocationUri": location.uri,
"S3Config": location.metadata["S3Config"],
}
)
def create_location_smb(self):
# smb://smb.share.fqdn/AWS_Test/
subdirectory = self._get_param("Subdirectory")
server_hostname = self._get_param("ServerHostname")
metadata = {
"AgentArns": self._get_param("AgentArns"),
"User": self._get_param("User"),
"Domain": self._get_param("Domain"),
"MountOptions": self._get_param("MountOptions"),
}
location_uri = "/".join(["smb:/", server_hostname, subdirectory])
arn = self.datasync_backend.create_location(
location_uri, metadata=metadata, typ="SMB"
)
return json.dumps({"LocationArn": arn})
def describe_location_smb(self):
location_arn = self._get_param("LocationArn")
location = self._get_location(location_arn, typ="SMB")
return json.dumps(
{
"LocationArn": location.arn,
"LocationUri": location.uri,
"AgentArns": location.metadata["AgentArns"],
"User": location.metadata["User"],
"Domain": location.metadata["Domain"],
"MountOptions": location.metadata["MountOptions"],
}
)
def delete_location(self):
location_arn = self._get_param("LocationArn")
self.datasync_backend.delete_location(location_arn)
return json.dumps({})
def create_task(self):
destination_location_arn = self._get_param("DestinationLocationArn")
source_location_arn = self._get_param("SourceLocationArn")
name = self._get_param("Name")
metadata = {
"CloudWatchLogGroupArn": self._get_param("CloudWatchLogGroupArn"),
"Options": self._get_param("Options"),
"Excludes": self._get_param("Excludes"),
"Tags": self._get_param("Tags"),
}
arn = self.datasync_backend.create_task(
source_location_arn, destination_location_arn, name, metadata=metadata
)
return json.dumps({"TaskArn": arn})
def update_task(self):
task_arn = self._get_param("TaskArn")
self.datasync_backend.update_task(
task_arn,
name=self._get_param("Name"),
metadata={
"CloudWatchLogGroupArn": self._get_param("CloudWatchLogGroupArn"),
"Options": self._get_param("Options"),
"Excludes": self._get_param("Excludes"),
"Tags": self._get_param("Tags"),
},
)
return json.dumps({})
def list_tasks(self):
tasks = list()
for arn, task in self.datasync_backend.tasks.items():
tasks.append(
{"Name": task.name, "Status": task.status, "TaskArn": task.arn}
)
return json.dumps({"Tasks": tasks})
def delete_task(self):
task_arn = self._get_param("TaskArn")
self.datasync_backend.delete_task(task_arn)
return json.dumps({})
def describe_task(self):
task_arn = self._get_param("TaskArn")
task = self.datasync_backend._get_task(task_arn)
return json.dumps(
{
"TaskArn": task.arn,
"Status": task.status,
"Name": task.name,
"CurrentTaskExecutionArn": task.current_task_execution_arn,
"SourceLocationArn": task.source_location_arn,
"DestinationLocationArn": task.destination_location_arn,
"CloudWatchLogGroupArn": task.metadata["CloudWatchLogGroupArn"],
"Options": task.metadata["Options"],
"Excludes": task.metadata["Excludes"],
}
)
def start_task_execution(self):
task_arn = self._get_param("TaskArn")
arn = self.datasync_backend.start_task_execution(task_arn)
return json.dumps({"TaskExecutionArn": arn})
def cancel_task_execution(self):
task_execution_arn = self._get_param("TaskExecutionArn")
self.datasync_backend.cancel_task_execution(task_execution_arn)
return json.dumps({})
def describe_task_execution(self):
task_execution_arn = self._get_param("TaskExecutionArn")
task_execution = self.datasync_backend._get_task_execution(task_execution_arn)
result = json.dumps(
{"TaskExecutionArn": task_execution.arn, "Status": task_execution.status}
)
if task_execution.status == "SUCCESS":
self.datasync_backend.tasks[task_execution.task_arn].status = "AVAILABLE"
# Simulate task being executed
task_execution.iterate_status()
return result

Some files were not shown because too many files have changed in this diff.