Merge pull request #1 from spulec/master

Align
This commit is contained in:
xnegativx 2020-12-03 10:38:30 +01:00 committed by GitHub
commit ac0ffab4d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
522 changed files with 63319 additions and 10187 deletions

View File

@ -3,6 +3,7 @@
exclude_lines =
if __name__ == .__main__.:
raise NotImplemented.
return NotImplemented
def __repr__
[run]

2
.gitignore vendored
View File

@ -22,3 +22,5 @@ tests/file.tmp
.eggs/
.mypy_cache/
*.tmp
.venv/
htmlcov/

View File

@ -1,4 +1,4 @@
dist: bionic
dist: focal
language: python
services:
- docker
@ -26,11 +26,13 @@ install:
fi
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh &
fi
travis_retry pip install -r requirements-dev.txt
travis_retry pip install docker>=2.5.1
travis_retry pip install boto==2.45.0
travis_retry pip install boto3
travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt
travis_retry pip install coverage==4.5.4
if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py

View File

@ -1,6 +1,224 @@
Moto Changelog
===================
Unreleased
-----
* Reduced dependency overhead.
It is now possible to install dependencies for only specific services using:
pip install moto[service1,service1].
See the README for more information.
1.3.16
-----
Full list of PRs merged in this release:
https://github.com/spulec/moto/pulls?q=is%3Apr+is%3Aclosed+merged%3A2019-11-14..2020-09-07
General Changes:
* The scaffold.py-script has been fixed to make it easier to scaffold new services.
See the README for an introduction.
New Services:
* Application Autoscaling
* Code Commit
* Code Pipeline
* Elastic Beanstalk
* Kinesis Video
* Kinesis Video Archived Media
* Managed BlockChain
* Resource Access Manager (ram)
* Sagemaker
New Methods:
* Athena:
* create_named_query
* get_named_query
* get_work_group
* start_query_execution
* stop_query_execution
* API Gateway:
* create_authorizer
* create_domain_name
* create_model
* delete_authorizer
* get_authorizer
* get_authorizers
* get_domain_name
* get_domain_names
* get_model
* get_models
* update_authorizer
* Autoscaling:
* enter_standby
* exit_standby
* terminate_instance_in_auto_scaling_group
* CloudFormation:
* get_template_summary
* CloudWatch:
* describe_alarms_for_metric
* get_metric_data
* CloudWatch Logs:
* delete_subscription_filter
* describe_subscription_filters
* put_subscription_filter
* Cognito IDP:
* associate_software_token
* create_resource_server
* confirm_sign_up
* initiate_auth
* set_user_mfa_preference
* sign_up
* verify_software_token
* DynamoDB:
* describe_continuous_backups
* transact_get_items
* transact_write_items
* update_continuous_backups
* EC2:
* create_vpc_endpoint
* describe_vpc_classic_link
* describe_vpc_classic_link_dns_support
* describe_vpc_endpoint_services
* disable_vpc_classic_link
* disable_vpc_classic_link_dns_support
* enable_vpc_classic_link
* enable_vpc_classic_link_dns_support
* register_image
* ECS:
* create_task_set
* delete_task_set
* describe_task_set
* update_service_primary_task_set
* update_task_set
* Events:
* delete_event_bus
* create_event_bus
* list_event_buses
* list_tags_for_resource
* tag_resource
* untag_resource
* Glue:
* get_databases
* IAM:
* delete_group
* delete_instance_profile
* delete_ssh_public_key
* get_account_summary
* get_ssh_public_key
* list_user_tags
* list_ssh_public_keys
* update_ssh_public_key
* upload_ssh_public_key
* IOT:
* cancel_job
* cancel_job_execution
* create_policy_version
* delete_job
* delete_job_execution
* describe_endpoint
* describe_job_execution
* delete_policy_version
* get_policy_version
* get_job_document
* list_attached_policies
* list_job_executions_for_job
* list_job_executions_for_thing
* list_jobs
* list_policy_versions
* set_default_policy_version
* register_certificate_without_ca
* KMS:
* untag_resource
* Lambda:
* delete_function_concurrency
* get_function_concurrency
* put_function_concurrency
* Organisations:
* describe_create_account_status
* deregister_delegated_administrator
* disable_policy_type
* enable_policy_type
* list_delegated_administrators
* list_delegated_services_for_account
* list_tags_for_resource
* register_delegated_administrator
* tag_resource
* untag_resource
* update_organizational_unit
* S3:
* delete_bucket_encryption
* delete_public_access_block
* get_bucket_encryption
* get_public_access_block
* put_bucket_encryption
* put_public_access_block
* S3 Control:
* delete_public_access_block
* get_public_access_block
* put_public_access_block
* SecretsManager:
* get_resource_policy
* update_secret
* SES:
* create_configuration_set
* create_configuration_set_event_destination
* create_receipt_rule_set
* create_receipt_rule
* create_template
* get_template
* get_send_statistics
* list_templates
* STS:
* assume_role_with_saml
* SSM:
* create_documen
* delete_document
* describe_document
* get_document
* list_documents
* update_document
* update_document_default_version
* SWF:
* undeprecate_activity_type
* undeprecate_domain
* undeprecate_workflow_type
General Updates:
* API Gateway - create_rest_api now supports policy-parameter
* Autoscaling - describe_auto_scaling_instances now supports InstanceIds-parameter
* AutoScalingGroups - now support launch templates
* CF - Now supports DependsOn-configuration
* CF - Now supports FN::Transform AWS::Include mapping
* CF - Now supports update and deletion of Lambdas
* CF - Now supports creation, update and deletion of EventBus (Events)
* CF - Now supports update of Rules (Events)
* CF - Now supports creation, update and deletion of EventSourceMappings (AWS Lambda)
* CF - Now supports update and deletion of Kinesis Streams
* CF - Now supports creation of DynamoDB streams
* CF - Now supports deletion of DynamoDB tables
* CF - list_stacks now supports the status_filter-parameter
* Cognito IDP - list_users now supports filter-parameter
* DynamoDB - GSI/LSI's now support ProjectionType=KEYS_ONLY
* EC2 - create_route now supports the NetworkInterfaceId-parameter
* EC2 - describe_instances now supports additional filters (owner-id)
* EC2 - describe_instance_status now supports additional filters (instance-state-name, instance-state-code)
* EC2 - describe_nat_gateways now supports additional filters (nat-gateway-id, vpc-id, subnet-id, state)
* EC2 - describe_vpn_gateways now supports additional filters (attachment.vpc_id, attachment.state, vpn-gateway-id, type)
* IAM - list_users now supports path_prefix-parameter
* IOT - list_thing_groups now supports parent_group, name_prefix_filter, recursive-parameters
* S3 - delete_objects now supports deletion of specific VersionIds
* SecretsManager - list_secrets now supports filters-parameter
* SFN - start_execution now receives and validates input
* SNS - Now supports sending a message directly to a phone number
* SQS - MessageAttributes now support labeled DataTypes
1.3.15
-----
This release broke dependency management for a lot of services - please upgrade to 1.3.16.
1.3.14
-----

View File

@ -23,8 +23,8 @@ However, this will only work on resource types that have this enabled.
### Current enabled resource types:
1. S3
1. S3 (all)
1. IAM (Role, Policy)
## Developer Guide
@ -53,15 +53,14 @@ An example of the above is implemented for S3. You can see that by looking at:
1. `moto/s3/config.py`
1. `moto/config/models.py`
As well as the corresponding unit tests in:
### Testing
For each resource type, you will need to test write tests for a few separate areas:
1. `tests/s3/test_s3.py`
1. `tests/config/test_config.py`
- Test the backend queries to ensure discovered resources come back (ie for `IAM::Policy`, write `tests.tests_iam.test_policy_list_config_discovered_resources`). For writing these tests, you must not make use of `boto` to create resources. You will need to use the backend model methods to provision the resources. This is to make tests compatible with the moto server. You must make tests for the resource type to test listing and object fetching.
Note for unit testing, you will want to add a test to ensure that you can query all the resources effectively. For testing this feature,
the unit tests for the `ConfigQueryModel` will not make use of `boto` to create resources, such as S3 buckets. You will need to use the
backend model methods to provision the resources. This is to make tests compatible with the moto server. You should absolutely make tests
in the resource type to test listing and object fetching.
- Test the config dict for all scenarios (ie for `IAM::Policy`, write `tests.tests_iam.test_policy_config_dict`). For writing this test, you'll need to create resources in the same way as the first test (without using `boto`), in every meaningful configuration that would produce a different config dict. Then, query the backend and ensure each of the dicts are as you expect.
- Test that everything works end to end with the `boto` clients. (ie for `IAM::Policy`, write `tests.tests_iam.test_policy_config_client`). The main two items to test will be the `boto.client('config').list_discovered_resources()`, `boto.client('config').list_aggregate_discovered_resources()`, `moto.client('config').batch_get_resource_config()`, and `moto.client('config').batch_aggregate_get_resource_config()`. This test doesn't have to be super thorough, but it basically tests that the front end and backend logic all works together and returns correct resources. Beware the aggregate methods all have capital first letters (ie `Limit`), while non-aggregate methods have lowercase first letters (ie `limit`)
### Listing
S3 is currently the model implementation, but it also odd in that S3 is a global resource type with regional resource residency.

View File

@ -1,29 +1,96 @@
### Contributing code
# Contributing code
Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project.
## Running the tests locally
Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
Moto has a [Makefile](./Makefile) which has some helpful commands for getting set up.
You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
## Is there a missing feature?
*NB. On first run, some tests might take a while to execute, especially the Lambda ones, because they may need to download a Docker image before they can execute.*
## Linting
Run `make lint` or `black --check moto tests` to verify whether your code confirms to the guidelines.
## Getting to grips with the codebase
Moto maintains a list of [good first issues](https://github.com/spulec/moto/contribute) which you may want to look at before
implementing a whole new endpoint.
## Missing features
Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services.
How to teach Moto to support a new AWS endpoint:
* Create an issue describing what's missing. This is where we'll all talk about the new addition and help you get it done.
* Search for an existing [issue](https://github.com/spulec/moto/issues) that matches what you want to achieve.
* If one doesn't already exist, create a new issue describing what's missing. This is where we'll all talk about the new addition and help you get it done.
* Create a [pull request](https://help.github.com/articles/using-pull-requests/) and mention the issue # in the PR description.
* Try to add a failing test case. For example, if you're trying to implement `boto3.client('acm').import_certificate()` you'll want to add a new method called `def test_import_certificate` to `tests/test_acm/test_acm.py`.
* If you can also implement the code that gets that test passing that's great. If not, just ask the community for a hand and somebody will assist you.
* Implementing the feature itself can be done by creating a method called `import_certificate` in `moto/acm/responses.py`. It's considered good practice to deal with input/output formatting and validation in `responses.py`, and create a method `import_certificate` in `moto/acm/models.py` that handles the actual import logic.
* If you can also implement the code that gets that test passing then great! If not, just ask the community for a hand and somebody will assist you.
# Maintainers
## Before pushing changes to GitHub
## Releasing a new version of Moto
1. Run `black moto/ tests/` over your code to ensure that it is properly formatted
1. Run `make test` to ensure your tests are passing
You'll need a PyPi account and a Dockerhub account to release Moto. After we release a new PyPi package we build and push the [motoserver/moto](https://hub.docker.com/r/motoserver/moto/) Docker image.
## Python versions
moto currently supports both Python 2 and 3, so make sure your tests pass against both major versions of Python.
## Missing services
Implementing a new service from scratch is more work, but still quite straightforward. All the code that intercepts network requests to `*.amazonaws.com` is already handled for you in `moto/core` - all that's necessary for new services to be recognized is to create a new decorator and determine which URLs should be intercepted.
See this PR for an example of what's involved in creating a new service: https://github.com/spulec/moto/pull/2409/files
Note the `urls.py` that redirects all incoming URL requests to a generic `dispatch` method, which in turn will call the appropriate method in `responses.py`.
If you want more control over incoming requests or their bodies, it is possible to redirect specific requests to a custom method. See this PR for an example: https://github.com/spulec/moto/pull/2957/files
### Generating template code of services.
By using `scripts/scaffold.py`, you can automatically generate template code of new services and new method of existing service. The script looks up API specification of given boto3 method and adds necessary codes includng request parameters and response parameters. In some cases, it fails to generate codes.
Please try out by runninig `python scripts/scaffold.py`
```bash
$ python scripts/scaffold.py
Select service: codedeploy
==Current Implementation Status==
[ ] add_tags_to_on_premises_instances
...
[ ] create_deployment
...[
[ ] update_deployment_group
=================================
Select Operation: create_deployment
Initializing service codedeploy
creating moto/codedeploy
creating moto/codedeploy/models.py
creating moto/codedeploy/exceptions.py
creating moto/codedeploy/__init__.py
creating moto/codedeploy/responses.py
creating moto/codedeploy/urls.py
creating tests/test_codedeploy
creating tests/test_codedeploy/test_server.py
creating tests/test_codedeploy/test_codedeploy.py
inserting code moto/codedeploy/responses.py
inserting code moto/codedeploy/models.py
You will still need to add the mock into "__init__.py"
```
## Maintainers
### Releasing a new version of Moto
You'll need a PyPi account and a DockerHub account to release Moto. After we release a new PyPi package we build and push the [motoserver/moto](https://hub.docker.com/r/motoserver/moto/) Docker image.
* First, `scripts/bump_version` modifies the version and opens a PR
* Then, merge the new pull request
* Finally, generate and ship the new artifacts with `make publish`

View File

@ -1,22 +1,12 @@
FROM alpine:3.6
RUN apk add --no-cache --update \
gcc \
musl-dev \
python3-dev \
libffi-dev \
openssl-dev \
python3
FROM python:3.7-slim
ADD . /moto/
ENV PYTHONUNBUFFERED 1
WORKDIR /moto/
RUN python3 -m ensurepip && \
rm -r /usr/lib/python*/ensurepip && \
pip3 --no-cache-dir install --upgrade pip setuptools && \
RUN pip3 --no-cache-dir install --upgrade pip setuptools && \
pip3 --no-cache-dir install ".[server]"
ENTRYPOINT ["/usr/bin/moto_server", "-H", "0.0.0.0"]
ENTRYPOINT ["/usr/local/bin/moto_server", "-H", "0.0.0.0"]
EXPOSE 5000

File diff suppressed because it is too large Load Diff

View File

@ -3,5 +3,6 @@ include requirements.txt requirements-dev.txt tox.ini
include moto/ec2/resources/instance_types.json
include moto/ec2/resources/amis.json
include moto/cognitoidp/resources/*.json
include moto/dynamodb2/parsing/reserved_keywords.txt
recursive-include moto/templates *
recursive-include tests *

View File

@ -3,7 +3,11 @@ SHELL := /bin/bash
ifeq ($(TEST_SERVER_MODE), true)
# exclude test_iot and test_iotdata for now
# because authentication of iot is very complicated
TEST_EXCLUDE := --exclude='test_iot.*'
# exclude test_kinesisvideoarchivedmedia
# because testing with moto_server is difficult with data-endpoint
TEST_EXCLUDE := -k 'not (test_iot or test_kinesisvideoarchivedmedia)'
else
TEST_EXCLUDE :=
endif
@ -19,13 +23,13 @@ lint:
test-only:
rm -f .coverage
rm -rf cover
@nosetests -sv --with-coverage --cover-html ./tests/ $(TEST_EXCLUDE)
@pytest -sv --cov=moto --cov-report html ./tests/ $(TEST_EXCLUDE)
test: lint test-only
test_server:
@TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/
@TEST_SERVER_MODE=true pytest -sv --cov=moto --cov-report html ./tests/
aws_managed_policies:
scripts/update_managed_policies.py
@ -35,7 +39,7 @@ upload_pypi_artifact:
twine upload dist/*
push_dockerhub_image:
docker build -t motoserver/moto .
docker build -t motoserver/moto . --tag moto:`python setup.py --version`
docker push motoserver/moto
tag_github_release:
@ -53,3 +57,6 @@ implementation_coverage:
scaffold:
@pip install -r requirements-dev.txt > /dev/null
exec python scripts/scaffold.py
int_test:
@./scripts/int_test.sh

173
README.md
View File

@ -9,6 +9,25 @@
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
## Install
To install moto for a specific service:
```console
$ pip install moto[ec2,s3]
```
This will install Moto, and the dependencies required for that specific service.
If you don't care about the number of dependencies, or if you want to mock many AWS services:
```console
$ pip install moto[all]
```
Not all services might be covered, in which case you might see a warning:
`moto 1.3.16 does not provide the extra 'service'`.
You can ignore the warning, or simply install moto as is:
```console
$ pip install moto
```
## In a nutshell
Moto is a library that allows your tests to easily mock out AWS Services.
@ -57,98 +76,58 @@ With the decorator wrapping the test, all the calls to s3 are automatically mock
It gets even better! Moto isn't just for Python code and it isn't just for S3. Look at the [standalone server mode](https://github.com/spulec/moto#stand-alone-server-mode) for more information about running Moto with other languages. Here's the status of the other AWS services implemented:
```gherkin
|-------------------------------------------------------------------------------------|
| Service Name | Decorator | Development Status |
|-------------------------------------------------------------------------------------|
| ACM | @mock_acm | all endpoints done |
|-------------------------------------------------------------------------------------|
| API Gateway | @mock_apigateway | core endpoints done |
|-------------------------------------------------------------------------------------|
| Autoscaling | @mock_autoscaling | core endpoints done |
|-------------------------------------------------------------------------------------|
| Cloudformation | @mock_cloudformation | core endpoints done |
|-------------------------------------------------------------------------------------|
| Cloudwatch | @mock_cloudwatch | basic endpoints done |
|-------------------------------------------------------------------------------------|
| CloudwatchEvents | @mock_events | all endpoints done |
|-------------------------------------------------------------------------------------|
| Cognito Identity | @mock_cognitoidentity | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Cognito Identity Provider | @mock_cognitoidp | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Config | @mock_config | basic endpoints done |
| | | core endpoints done |
|-------------------------------------------------------------------------------------|
| Data Pipeline | @mock_datapipeline | basic endpoints done |
|-------------------------------------------------------------------------------------|
| DynamoDB | @mock_dynamodb | core endpoints done |
| DynamoDB2 | @mock_dynamodb2 | all endpoints + partial indexes |
|-------------------------------------------------------------------------------------|
| EC2 | @mock_ec2 | core endpoints done |
| - AMI | | core endpoints done |
| - EBS | | core endpoints done |
| - Instances | | all endpoints done |
| - Security Groups | | core endpoints done |
| - Tags | | all endpoints done |
|-------------------------------------------------------------------------------------|
| ECR | @mock_ecr | basic endpoints done |
|-------------------------------------------------------------------------------------|
| ECS | @mock_ecs | basic endpoints done |
|-------------------------------------------------------------------------------------|
| ELB | @mock_elb | core endpoints done |
|-------------------------------------------------------------------------------------|
| ELBv2 | @mock_elbv2 | all endpoints done |
|-------------------------------------------------------------------------------------|
| EMR | @mock_emr | core endpoints done |
|-------------------------------------------------------------------------------------|
| Glacier | @mock_glacier | core endpoints done |
|-------------------------------------------------------------------------------------|
| IAM | @mock_iam | core endpoints done |
|-------------------------------------------------------------------------------------|
| IoT | @mock_iot | core endpoints done |
| | @mock_iotdata | core endpoints done |
|-------------------------------------------------------------------------------------|
| Kinesis | @mock_kinesis | core endpoints done |
|-------------------------------------------------------------------------------------|
| KMS | @mock_kms | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Lambda | @mock_lambda | basic endpoints done, requires |
| | | docker |
|-------------------------------------------------------------------------------------|
| Logs | @mock_logs | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Organizations | @mock_organizations | some core endpoints done |
|-------------------------------------------------------------------------------------|
| Polly | @mock_polly | all endpoints done |
|-------------------------------------------------------------------------------------|
| RDS | @mock_rds | core endpoints done |
|-------------------------------------------------------------------------------------|
| RDS2 | @mock_rds2 | core endpoints done |
|-------------------------------------------------------------------------------------|
| Redshift | @mock_redshift | core endpoints done |
|-------------------------------------------------------------------------------------|
| Route53 | @mock_route53 | core endpoints done |
|-------------------------------------------------------------------------------------|
| S3 | @mock_s3 | core endpoints done |
|-------------------------------------------------------------------------------------|
| SecretsManager | @mock_secretsmanager | basic endpoints done |
|-------------------------------------------------------------------------------------|
| SES | @mock_ses | all endpoints done |
|-------------------------------------------------------------------------------------|
| SNS | @mock_sns | all endpoints done |
|-------------------------------------------------------------------------------------|
| SQS | @mock_sqs | core endpoints done |
|-------------------------------------------------------------------------------------|
| SSM | @mock_ssm | core endpoints done |
|-------------------------------------------------------------------------------------|
| STS | @mock_sts | core endpoints done |
|-------------------------------------------------------------------------------------|
| SWF | @mock_swf | basic endpoints done |
|-------------------------------------------------------------------------------------|
| X-Ray | @mock_xray | all endpoints done |
|-------------------------------------------------------------------------------------|
```
| Service Name | Decorator | Development Status | Comment |
|---------------------------|-----------------------|---------------------------------|-----------------------------|
| ACM | @mock_acm | all endpoints done | |
| API Gateway | @mock_apigateway | core endpoints done | |
| Application Autoscaling | @mock_applicationautoscaling | basic endpoints done | |
| Athena | @mock_athena | core endpoints done | |
| Autoscaling | @mock_autoscaling | core endpoints done | |
| Cloudformation | @mock_cloudformation | core endpoints done | |
| Cloudwatch | @mock_cloudwatch | basic endpoints done | |
| CloudwatchEvents | @mock_events | all endpoints done | |
| Cognito Identity | @mock_cognitoidentity | basic endpoints done | |
| Cognito Identity Provider | @mock_cognitoidp | basic endpoints done | |
| Config | @mock_config | basic + core endpoints done | |
| Data Pipeline | @mock_datapipeline | basic endpoints done | |
| DynamoDB | @mock_dynamodb | core endpoints done | API 20111205. Deprecated. |
| DynamoDB2 | @mock_dynamodb2 | all endpoints + partial indexes | API 20120810 (Latest) |
| EC2 | @mock_ec2 | core endpoints done | |
| - AMI | | core endpoints done | |
| - EBS | | core endpoints done | |
| - Instances | | all endpoints done | |
| - Security Groups | | core endpoints done | |
| - Tags | | all endpoints done | |
| ECR | @mock_ecr | basic endpoints done | |
| ECS | @mock_ecs | basic endpoints done | |
| ELB | @mock_elb | core endpoints done | |
| ELBv2 | @mock_elbv2 | all endpoints done | |
| EMR | @mock_emr | core endpoints done | |
| Forecast | @mock_forecast | some core endpoints done | |
| Glacier | @mock_glacier | core endpoints done | |
| Glue | @mock_glue | core endpoints done | |
| IAM | @mock_iam | core endpoints done | |
| IoT | @mock_iot | core endpoints done | |
| IoT data | @mock_iotdata | core endpoints done | |
| Kinesis | @mock_kinesis | core endpoints done | |
| KMS | @mock_kms | basic endpoints done | |
| Lambda | @mock_lambda | basic endpoints done, requires docker | |
| Logs | @mock_logs | basic endpoints done | |
| Organizations | @mock_organizations | some core endpoints done | |
| Polly | @mock_polly | all endpoints done | |
| RDS | @mock_rds | core endpoints done | |
| RDS2 | @mock_rds2 | core endpoints done | |
| Redshift | @mock_redshift | core endpoints done | |
| Route53 | @mock_route53 | core endpoints done | |
| S3 | @mock_s3 | core endpoints done | |
| SecretsManager | @mock_secretsmanager | basic endpoints done | |
| SES | @mock_ses | all endpoints done | |
| SNS | @mock_sns | all endpoints done | |
| SQS | @mock_sqs | core endpoints done | |
| SSM | @mock_ssm | core endpoints done | |
| STS | @mock_sts | core endpoints done | |
| SWF | @mock_swf | basic endpoints done | |
| X-Ray | @mock_xray | all endpoints done | |
For a full list of endpoint [implementation coverage](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md)
@ -450,12 +429,14 @@ boto3.resource(
)
```
## Install
### Caveats
The standalone server has some caveats with some services. The following services
require that you update your hosts file for your code to work properly:
1. `s3-control`
```console
$ pip install moto
```
For the above services, this is required because the hostname is in the form of `AWS_ACCOUNT_ID.localhost`.
As a result, you need to add that entry to your host file for your tests to function properly.
## Releases

View File

@ -20,12 +20,12 @@ import shlex
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#sys.path.insert(0, os.path.abspath('.'))
# sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0'
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
@ -33,32 +33,34 @@ import shlex
extensions = []
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = ".rst"
# The encoding of source files.
#source_encoding = 'utf-8-sig'
# source_encoding = 'utf-8-sig'
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# General information about the project.
project = 'Moto'
copyright = '2015, Steve Pulec'
author = 'Steve Pulec'
project = "Moto"
copyright = "2015, Steve Pulec"
author = "Steve Pulec"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.4.10'
import moto
version = moto.__version__
# The full version, including alpha/beta/rc tags.
release = '0.4.10'
release = moto.__version__
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@ -69,37 +71,37 @@ language = None
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'
# today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build']
exclude_patterns = ["_build"]
# The reST default role (used for this markup: `text`) to use for all
# documents.
#default_role = None
# default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True
# add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True
# add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False
# show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []
# modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built documents.
#keep_warnings = False
# keep_warnings = False
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
@ -109,156 +111,149 @@ todo_include_todos = False
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#html_theme_options = {}
# html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = []
# html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
#html_title = None
# html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.
#html_short_title = None
# html_short_title = None
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
#html_logo = None
# html_logo = None
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
#html_favicon = None
# html_favicon = None
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
# directly to the root of the documentation.
#html_extra_path = []
# html_extra_path = []
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y'
# html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
#html_use_smartypants = True
# html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
#html_sidebars = {}
# html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names to
# template names.
#html_additional_pages = {}
# html_additional_pages = {}
# If false, no module index is generated.
#html_domain_indices = True
# html_domain_indices = True
# If false, no index is generated.
#html_use_index = True
# html_use_index = True
# If true, the index is split into individual pages for each letter.
#html_split_index = False
# html_split_index = False
# If true, links to the reST sources are added to the pages.
#html_show_sourcelink = True
# html_show_sourcelink = True
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True
# html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True
# html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
#html_use_opensearch = ''
# html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None
# html_file_suffix = None
# Language to be used for generating the HTML full-text search index.
# Sphinx supports the following languages:
# 'da', 'de', 'en', 'es', 'fi', 'fr', 'h', 'it', 'ja'
# 'nl', 'no', 'pt', 'ro', 'r', 'sv', 'tr'
#html_search_language = 'en'
# html_search_language = 'en'
# A dictionary with options for the search language support, empty by default.
# Now only 'ja' uses this config value
#html_search_options = {'type': 'default'}
# html_search_options = {'type': 'default'}
# The name of a javascript file (relative to the configuration directory) that
# implements a search results scorer. If empty, the default will be used.
#html_search_scorer = 'scorer.js'
# html_search_scorer = 'scorer.js'
# Output file base name for HTML help builder.
htmlhelp_basename = 'Motodoc'
htmlhelp_basename = "Motodoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'Moto.tex', 'Moto Documentation',
'Steve Pulec', 'manual'),
(master_doc, "Moto.tex", "Moto Documentation", "Steve Pulec", "manual"),
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
#latex_logo = None
# latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
#latex_use_parts = False
# latex_use_parts = False
# If true, show page references after internal links.
#latex_show_pagerefs = False
# latex_show_pagerefs = False
# If true, show URL addresses after external links.
#latex_show_urls = False
# latex_show_urls = False
# Documents to append as an appendix to all manuals.
#latex_appendices = []
# latex_appendices = []
# If false, no module index is generated.
#latex_domain_indices = True
# latex_domain_indices = True
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'moto', 'Moto Documentation',
[author], 1)
]
man_pages = [(master_doc, "moto", "Moto Documentation", [author], 1)]
# If true, show URL addresses after external links.
#man_show_urls = False
# man_show_urls = False
# -- Options for Texinfo output -------------------------------------------
@ -267,19 +262,25 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'Moto', 'Moto Documentation',
author, 'Moto', 'One line description of project.',
'Miscellaneous'),
(
master_doc,
"Moto",
"Moto Documentation",
author,
"Moto",
"One line description of project.",
"Miscellaneous",
),
]
# Documents to append as an appendix to all manuals.
#texinfo_appendices = []
# texinfo_appendices = []
# If false, no module index is generated.
#texinfo_domain_indices = True
# texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote'
# texinfo_show_urls = 'footnote'
# If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False
# texinfo_no_detailmenu = False

View File

@ -24,8 +24,7 @@ For example, we have the following code we want to test:
.. sourcecode:: python
import boto
from boto.s3.key import Key
import boto3
class MyModel(object):
def __init__(self, name, value):
@ -33,11 +32,8 @@ For example, we have the following code we want to test:
self.value = value
def save(self):
conn = boto.connect_s3()
bucket = conn.get_bucket('mybucket')
k = Key(bucket)
k.key = self.name
k.set_contents_from_string(self.value)
s3 = boto3.client('s3', region_name='us-east-1')
s3.put_object(Bucket='mybucket', Key=self.name, Body=self.value)
There are several ways to do this, but you should keep in mind that Moto creates a full, blank environment.
@ -48,20 +44,23 @@ With a decorator wrapping, all the calls to S3 are automatically mocked out.
.. sourcecode:: python
import boto
import boto3
from moto import mock_s3
from mymodule import MyModel
@mock_s3
def test_my_model_save():
conn = boto.connect_s3()
conn = boto3.resource('s3', region_name='us-east-1')
# We need to create the bucket since this is all in Moto's 'virtual' AWS account
conn.create_bucket('mybucket')
conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
Context manager
~~~~~~~~~~~~~~~
@ -72,13 +71,16 @@ Same as the Decorator, every call inside the ``with`` statement is mocked out.
def test_my_model_save():
with mock_s3():
conn = boto.connect_s3()
conn.create_bucket('mybucket')
conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
Raw
~~~
@ -91,13 +93,16 @@ You can also start and stop the mocking manually.
mock = mock_s3()
mock.start()
conn = boto.connect_s3()
conn.create_bucket('mybucket')
conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
mock.stop()

View File

@ -60,6 +60,8 @@ Currently implemented Services:
+---------------------------+-----------------------+------------------------------------+
| EMR | @mock_emr | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Forecast | @mock_forecast | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Glacier | @mock_glacier | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| IAM | @mock_iam | core endpoints done |

View File

@ -1,67 +1,130 @@
from __future__ import unicode_literals
from .acm import mock_acm # noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # noqa
from .athena import mock_athena # noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # noqa
from .batch import mock_batch # noqa
from .cloudformation import mock_cloudformation # noqa
from .cloudformation import mock_cloudformation_deprecated # noqa
from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # noqa
from .codecommit import mock_codecommit # noqa
from .codepipeline import mock_codepipeline # noqa
from .cognitoidentity import mock_cognitoidentity # noqa
from .cognitoidentity import mock_cognitoidentity_deprecated # noqa
from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # noqa
from .config import mock_config # noqa
from .datapipeline import mock_datapipeline # noqa
from .datapipeline import mock_datapipeline_deprecated # noqa
from .datasync import mock_datasync # noqa
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # noqa
from .dynamodbstreams import mock_dynamodbstreams # noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # noqa
from .ec2_instance_connect import mock_ec2_instance_connect # noqa
from .ecr import mock_ecr, mock_ecr_deprecated # noqa
from .ecs import mock_ecs, mock_ecs_deprecated # noqa
from .elb import mock_elb, mock_elb_deprecated # noqa
from .elbv2 import mock_elbv2 # noqa
from .emr import mock_emr, mock_emr_deprecated # noqa
from .events import mock_events # noqa
from .glacier import mock_glacier, mock_glacier_deprecated # noqa
from .glue import mock_glue # noqa
from .iam import mock_iam, mock_iam_deprecated # noqa
from .iot import mock_iot # noqa
from .iotdata import mock_iotdata # noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # noqa
from .kms import mock_kms, mock_kms_deprecated # noqa
from .logs import mock_logs, mock_logs_deprecated # noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # noqa
from .organizations import mock_organizations # noqa
from .polly import mock_polly # noqa
from .rds import mock_rds, mock_rds_deprecated # noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # noqa
from .redshift import mock_redshift, mock_redshift_deprecated # noqa
from .resourcegroups import mock_resourcegroups # noqa
from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # noqa
from .route53 import mock_route53, mock_route53_deprecated # noqa
from .s3 import mock_s3, mock_s3_deprecated # noqa
from .secretsmanager import mock_secretsmanager # noqa
from .ses import mock_ses, mock_ses_deprecated # noqa
from .sns import mock_sns, mock_sns_deprecated # noqa
from .sqs import mock_sqs, mock_sqs_deprecated # noqa
from .ssm import mock_ssm # noqa
from .stepfunctions import mock_stepfunctions # noqa
from .sts import mock_sts, mock_sts_deprecated # noqa
from .swf import mock_swf, mock_swf_deprecated # noqa
from .xray import XRaySegment, mock_xray, mock_xray_client # noqa
import importlib
def lazy_load(module_name, element):
def f(*args, **kwargs):
module = importlib.import_module(module_name, "moto")
return getattr(module, element)(*args, **kwargs)
return f
mock_acm = lazy_load(".acm", "mock_acm")
mock_apigateway = lazy_load(".apigateway", "mock_apigateway")
mock_apigateway_deprecated = lazy_load(".apigateway", "mock_apigateway_deprecated")
mock_athena = lazy_load(".athena", "mock_athena")
mock_applicationautoscaling = lazy_load(
".applicationautoscaling", "mock_applicationautoscaling"
)
mock_autoscaling = lazy_load(".autoscaling", "mock_autoscaling")
mock_autoscaling_deprecated = lazy_load(".autoscaling", "mock_autoscaling_deprecated")
mock_lambda = lazy_load(".awslambda", "mock_lambda")
mock_lambda_deprecated = lazy_load(".awslambda", "mock_lambda_deprecated")
mock_batch = lazy_load(".batch", "mock_batch")
mock_batch = lazy_load(".batch", "mock_batch")
mock_cloudformation = lazy_load(".cloudformation", "mock_cloudformation")
mock_cloudformation_deprecated = lazy_load(
".cloudformation", "mock_cloudformation_deprecated"
)
mock_cloudwatch = lazy_load(".cloudwatch", "mock_cloudwatch")
mock_cloudwatch_deprecated = lazy_load(".cloudwatch", "mock_cloudwatch_deprecated")
mock_codecommit = lazy_load(".codecommit", "mock_codecommit")
mock_codepipeline = lazy_load(".codepipeline", "mock_codepipeline")
mock_cognitoidentity = lazy_load(".cognitoidentity", "mock_cognitoidentity")
mock_cognitoidentity_deprecated = lazy_load(
".cognitoidentity", "mock_cognitoidentity_deprecated"
)
mock_cognitoidp = lazy_load(".cognitoidp", "mock_cognitoidp")
mock_cognitoidp_deprecated = lazy_load(".cognitoidp", "mock_cognitoidp_deprecated")
mock_config = lazy_load(".config", "mock_config")
mock_datapipeline = lazy_load(".datapipeline", "mock_datapipeline")
mock_datapipeline_deprecated = lazy_load(
".datapipeline", "mock_datapipeline_deprecated"
)
mock_datasync = lazy_load(".datasync", "mock_datasync")
mock_dynamodb = lazy_load(".dynamodb", "mock_dynamodb")
mock_dynamodb_deprecated = lazy_load(".dynamodb", "mock_dynamodb_deprecated")
mock_dynamodb2 = lazy_load(".dynamodb2", "mock_dynamodb2")
mock_dynamodb2_deprecated = lazy_load(".dynamodb2", "mock_dynamodb2_deprecated")
mock_dynamodbstreams = lazy_load(".dynamodbstreams", "mock_dynamodbstreams")
mock_elasticbeanstalk = lazy_load(".elasticbeanstalk", "mock_elasticbeanstalk")
mock_ec2 = lazy_load(".ec2", "mock_ec2")
mock_ec2_deprecated = lazy_load(".ec2", "mock_ec2_deprecated")
mock_ec2instanceconnect = lazy_load(".ec2instanceconnect", "mock_ec2instanceconnect")
mock_ecr = lazy_load(".ecr", "mock_ecr")
mock_ecr_deprecated = lazy_load(".ecr", "mock_ecr_deprecated")
mock_ecs = lazy_load(".ecs", "mock_ecs")
mock_ecs_deprecated = lazy_load(".ecs", "mock_ecs_deprecated")
mock_elb = lazy_load(".elb", "mock_elb")
mock_elb_deprecated = lazy_load(".elb", "mock_elb_deprecated")
mock_elbv2 = lazy_load(".elbv2", "mock_elbv2")
mock_emr = lazy_load(".emr", "mock_emr")
mock_emr_deprecated = lazy_load(".emr", "mock_emr_deprecated")
mock_events = lazy_load(".events", "mock_events")
mock_forecast = lazy_load(".forecast", "mock_forecast")
mock_glacier = lazy_load(".glacier", "mock_glacier")
mock_glacier_deprecated = lazy_load(".glacier", "mock_glacier_deprecated")
mock_glue = lazy_load(".glue", "mock_glue")
mock_iam = lazy_load(".iam", "mock_iam")
mock_iam_deprecated = lazy_load(".iam", "mock_iam_deprecated")
mock_iot = lazy_load(".iot", "mock_iot")
mock_iotdata = lazy_load(".iotdata", "mock_iotdata")
mock_kinesis = lazy_load(".kinesis", "mock_kinesis")
mock_kinesis_deprecated = lazy_load(".kinesis", "mock_kinesis_deprecated")
mock_kms = lazy_load(".kms", "mock_kms")
mock_kms_deprecated = lazy_load(".kms", "mock_kms_deprecated")
mock_logs = lazy_load(".logs", "mock_logs")
mock_logs_deprecated = lazy_load(".logs", "mock_logs_deprecated")
mock_managedblockchain = lazy_load(".managedblockchain", "mock_managedblockchain")
mock_opsworks = lazy_load(".opsworks", "mock_opsworks")
mock_opsworks_deprecated = lazy_load(".opsworks", "mock_opsworks_deprecated")
mock_organizations = lazy_load(".organizations", "mock_organizations")
mock_polly = lazy_load(".polly", "mock_polly")
mock_ram = lazy_load(".ram", "mock_ram")
mock_rds = lazy_load(".rds", "mock_rds")
mock_rds_deprecated = lazy_load(".rds", "mock_rds_deprecated")
mock_rds2 = lazy_load(".rds2", "mock_rds2")
mock_rds2_deprecated = lazy_load(".rds2", "mock_rds2_deprecated")
mock_redshift = lazy_load(".redshift", "mock_redshift")
mock_redshift_deprecated = lazy_load(".redshift", "mock_redshift_deprecated")
mock_resourcegroups = lazy_load(".resourcegroups", "mock_resourcegroups")
mock_resourcegroupstaggingapi = lazy_load(
".resourcegroupstaggingapi", "mock_resourcegroupstaggingapi"
)
mock_route53 = lazy_load(".route53", "mock_route53")
mock_route53_deprecated = lazy_load(".route53", "mock_route53_deprecated")
mock_s3 = lazy_load(".s3", "mock_s3")
mock_s3_deprecated = lazy_load(".s3", "mock_s3_deprecated")
mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker")
mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager")
mock_ses = lazy_load(".ses", "mock_ses")
mock_ses_deprecated = lazy_load(".ses", "mock_ses_deprecated")
mock_sns = lazy_load(".sns", "mock_sns")
mock_sns_deprecated = lazy_load(".sns", "mock_sns_deprecated")
mock_sqs = lazy_load(".sqs", "mock_sqs")
mock_sqs_deprecated = lazy_load(".sqs", "mock_sqs_deprecated")
mock_ssm = lazy_load(".ssm", "mock_ssm")
mock_stepfunctions = lazy_load(".stepfunctions", "mock_stepfunctions")
mock_sts = lazy_load(".sts", "mock_sts")
mock_sts_deprecated = lazy_load(".sts", "mock_sts_deprecated")
mock_swf = lazy_load(".swf", "mock_swf")
mock_swf_deprecated = lazy_load(".swf", "mock_swf_deprecated")
mock_transcribe = lazy_load(".transcribe", "mock_transcribe")
XRaySegment = lazy_load(".xray", "XRaySegment")
mock_xray = lazy_load(".xray", "mock_xray")
mock_xray_client = lazy_load(".xray", "mock_xray_client")
mock_kinesisvideo = lazy_load(".kinesisvideo", "mock_kinesisvideo")
mock_kinesisvideoarchivedmedia = lazy_load(
".kinesisvideoarchivedmedia", "mock_kinesisvideoarchivedmedia"
)
# import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = "moto"
__version__ = "1.3.15.dev"
__version__ = "1.3.16.dev"
try:

View File

@ -1,9 +1,9 @@
from __future__ import unicode_literals
import re
import json
import datetime
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import AWSError
from moto.ec2 import ec2_backends
from .utils import make_arn_for_certificate
@ -50,18 +50,6 @@ def datetime_to_epoch(date):
return int((date - datetime.datetime(1970, 1, 1)).total_seconds())
class AWSError(Exception):
TYPE = None
STATUS = 400
def __init__(self, message):
self.message = message
def response(self):
resp = {"__type": self.TYPE, "message": self.message}
return json.dumps(resp), dict(status=self.STATUS)
class AWSValidationException(AWSError):
TYPE = "ValidationException"
@ -70,6 +58,68 @@ class AWSResourceNotFoundException(AWSError):
TYPE = "ResourceNotFoundException"
class AWSTooManyTagsException(AWSError):
TYPE = "TooManyTagsException"
class TagHolder(dict):
MAX_TAG_COUNT = 50
MAX_KEY_LENGTH = 128
MAX_VALUE_LENGTH = 256
def _validate_kv(self, key, value, index):
if len(key) > self.MAX_KEY_LENGTH:
raise AWSValidationException(
"Value '%s' at 'tags.%d.member.key' failed to satisfy constraint: Member must have length less than or equal to %s"
% (key, index, self.MAX_KEY_LENGTH)
)
if value and len(value) > self.MAX_VALUE_LENGTH:
raise AWSValidationException(
"Value '%s' at 'tags.%d.member.value' failed to satisfy constraint: Member must have length less than or equal to %s"
% (value, index, self.MAX_VALUE_LENGTH)
)
if key.startswith("aws:"):
raise AWSValidationException(
'Invalid Tag Key: "%s". AWS internal tags cannot be changed with this API'
% key
)
def add(self, tags):
tags_copy = self.copy()
for i, tag in enumerate(tags):
key = tag["Key"]
value = tag.get("Value", None)
self._validate_kv(key, value, i + 1)
tags_copy[key] = value
if len(tags_copy) > self.MAX_TAG_COUNT:
raise AWSTooManyTagsException(
"the TagSet: '{%s}' contains too many Tags"
% ", ".join(k + "=" + str(v or "") for k, v in tags_copy.items())
)
self.update(tags_copy)
def remove(self, tags):
for i, tag in enumerate(tags):
key = tag["Key"]
value = tag.get("Value", None)
self._validate_kv(key, value, i + 1)
try:
# If value isnt provided, just delete key
if value is None:
del self[key]
# If value is provided, only delete if it matches what already exists
elif self[key] == value:
del self[key]
except KeyError:
pass
def equals(self, tags):
tags = {t["Key"]: t.get("Value", None) for t in tags} if tags else {}
return self == tags
class CertBundle(BaseModel):
def __init__(
self,
@ -88,7 +138,7 @@ class CertBundle(BaseModel):
self.key = private_key
self._key = None
self.chain = chain
self.tags = {}
self.tags = TagHolder()
self._chain = None
self.type = cert_type # Should really be an enum
self.status = cert_status # Should really be an enum
@ -293,9 +343,12 @@ class CertBundle(BaseModel):
key_algo = "EC_prime256v1"
# Look for SANs
san_obj = self._cert.extensions.get_extension_for_oid(
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
)
try:
san_obj = self._cert.extensions.get_extension_for_oid(
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
)
except cryptography.x509.ExtensionNotFound:
san_obj = None
sans = []
if san_obj is not None:
sans = [item.value for item in san_obj.value]
@ -385,7 +438,7 @@ class AWSCertificateManagerBackend(BaseBackend):
"expires": datetime.datetime.now() + datetime.timedelta(hours=1),
}
def import_cert(self, certificate, private_key, chain=None, arn=None):
def import_cert(self, certificate, private_key, chain=None, arn=None, tags=None):
if arn is not None:
if arn not in self._certificates:
raise self._arn_not_found(arn)
@ -400,6 +453,9 @@ class AWSCertificateManagerBackend(BaseBackend):
self._certificates[bundle.arn] = bundle
if tags:
self.add_tags_to_certificate(bundle.arn, tags)
return bundle.arn
def get_certificates_list(self, statuses):
@ -434,10 +490,11 @@ class AWSCertificateManagerBackend(BaseBackend):
domain_validation_options,
idempotency_token,
subject_alt_names,
tags=None,
):
if idempotency_token is not None:
arn = self._get_arn_from_idempotency_token(idempotency_token)
if arn is not None:
if arn and self._certificates[arn].tags.equals(tags):
return arn
cert = CertBundle.generate_cert(
@ -447,34 +504,20 @@ class AWSCertificateManagerBackend(BaseBackend):
self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert
if tags:
cert.tags.add(tags)
return cert.arn
def add_tags_to_certificate(self, arn, tags):
# get_cert does arn check
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag["Key"]
value = tag.get("Value", None)
cert_bundle.tags[key] = value
cert_bundle.tags.add(tags)
def remove_tags_from_certificate(self, arn, tags):
# get_cert does arn check
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag["Key"]
value = tag.get("Value", None)
try:
# If value isnt provided, just delete key
if value is None:
del cert_bundle.tags[key]
# If value is provided, only delete if it matches what already exists
elif cert_bundle.tags[key] == value:
del cert_bundle.tags[key]
except KeyError:
pass
cert_bundle.tags.remove(tags)
acm_backends = {}

View File

@ -117,6 +117,7 @@ class AWSCertificateManagerResponse(BaseResponse):
private_key = self._get_param("PrivateKey")
chain = self._get_param("CertificateChain") # Optional
current_arn = self._get_param("CertificateArn") # Optional
tags = self._get_param("Tags") # Optional
# Simple parameter decoding. Rather do it here as its a data transport decision not part of the
# actual data
@ -142,7 +143,7 @@ class AWSCertificateManagerResponse(BaseResponse):
try:
arn = self.acm_backend.import_cert(
certificate, private_key, chain=chain, arn=current_arn
certificate, private_key, chain=chain, arn=current_arn, tags=tags
)
except AWSError as err:
return err.response()
@ -210,6 +211,7 @@ class AWSCertificateManagerResponse(BaseResponse):
) # is ignored atm
idempotency_token = self._get_param("IdempotencyToken")
subject_alt_names = self._get_param("SubjectAlternativeNames")
tags = self._get_param("Tags") # Optional
if subject_alt_names is not None and len(subject_alt_names) > 10:
# There is initial AWS limit of 10
@ -227,6 +229,7 @@ class AWSCertificateManagerResponse(BaseResponse):
domain_validation_options,
idempotency_token,
subject_alt_names,
tags,
)
except AWSError as err:
return err.response()

View File

@ -85,6 +85,15 @@ class NoMethodDefined(BadRequestException):
)
class AuthorizerNotFoundException(RESTError):
code = 404
def __init__(self):
super(AuthorizerNotFoundException, self).__init__(
"NotFoundException", "Invalid Authorizer identifier specified"
)
class StageNotFoundException(RESTError):
code = 404
@ -103,6 +112,15 @@ class ApiKeyNotFoundException(RESTError):
)
class UsagePlanNotFoundException(RESTError):
code = 404
def __init__(self):
super(UsagePlanNotFoundException, self).__init__(
"NotFoundException", "Invalid Usage Plan ID specified"
)
class ApiKeyAlreadyExists(RESTError):
code = 409
@ -110,3 +128,57 @@ class ApiKeyAlreadyExists(RESTError):
super(ApiKeyAlreadyExists, self).__init__(
"ConflictException", "API Key already exists"
)
class InvalidDomainName(BadRequestException):
code = 404
def __init__(self):
super(InvalidDomainName, self).__init__(
"BadRequestException", "No Domain Name specified"
)
class DomainNameNotFound(RESTError):
code = 404
def __init__(self):
super(DomainNameNotFound, self).__init__(
"NotFoundException", "Invalid Domain Name specified"
)
class InvalidRestApiId(BadRequestException):
code = 404
def __init__(self):
super(InvalidRestApiId, self).__init__(
"BadRequestException", "No Rest API Id specified"
)
class InvalidModelName(BadRequestException):
code = 404
def __init__(self):
super(InvalidModelName, self).__init__(
"BadRequestException", "No Model Name specified"
)
class RestAPINotFound(RESTError):
code = 404
def __init__(self):
super(RestAPINotFound, self).__init__(
"NotFoundException", "Invalid Rest API Id specified"
)
class ModelNotFound(RESTError):
code = 404
def __init__(self):
super(ModelNotFound, self).__init__(
"NotFoundException", "Invalid Model Name specified"
)

View File

@ -14,12 +14,12 @@ try:
except ImportError:
from urllib.parse import urlparse
import responses
from moto.core import BaseBackend, BaseModel
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
from .utils import create_id
from moto.core.utils import path_url
from moto.sts.models import ACCOUNT_ID
from .exceptions import (
ApiKeyNotFoundException,
UsagePlanNotFoundException,
AwsProxyNotAllowed,
CrossAccountNotAllowed,
IntegrationMethodNotDefined,
@ -28,11 +28,18 @@ from .exceptions import (
InvalidHttpEndpoint,
InvalidResourcePathException,
InvalidRequestInput,
AuthorizerNotFoundException,
StageNotFoundException,
RoleNotSpecified,
NoIntegrationDefined,
NoMethodDefined,
ApiKeyAlreadyExists,
DomainNameNotFound,
InvalidDomainName,
InvalidRestApiId,
InvalidModelName,
RestAPINotFound,
ModelNotFound,
)
STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
@ -48,11 +55,21 @@ class Deployment(BaseModel, dict):
class IntegrationResponse(BaseModel, dict):
def __init__(self, status_code, selection_pattern=None):
self["responseTemplates"] = {"application/json": None}
def __init__(
self,
status_code,
selection_pattern=None,
response_templates=None,
content_handling=None,
):
if response_templates is None:
response_templates = {"application/json": None}
self["responseTemplates"] = response_templates
self["statusCode"] = status_code
if selection_pattern:
self["selectionPattern"] = selection_pattern
if content_handling:
self["contentHandling"] = content_handling
class Integration(BaseModel, dict):
@ -64,8 +81,14 @@ class Integration(BaseModel, dict):
self["requestTemplates"] = request_templates
self["integrationResponses"] = {"200": IntegrationResponse(200)}
def create_integration_response(self, status_code, selection_pattern):
integration_response = IntegrationResponse(status_code, selection_pattern)
def create_integration_response(
self, status_code, selection_pattern, response_templates, content_handling
):
if response_templates == {}:
response_templates = None
integration_response = IntegrationResponse(
status_code, selection_pattern, response_templates, content_handling
)
self["integrationResponses"][status_code] = integration_response
return integration_response
@ -83,14 +106,14 @@ class MethodResponse(BaseModel, dict):
class Method(BaseModel, dict):
def __init__(self, method_type, authorization_type):
def __init__(self, method_type, authorization_type, **kwargs):
super(Method, self).__init__()
self.update(
dict(
httpMethod=method_type,
authorizationType=authorization_type,
authorizerId=None,
apiKeyRequired=None,
apiKeyRequired=kwargs.get("api_key_required") or False,
requestParameters=None,
requestModels=None,
methodIntegration=None,
@ -117,14 +140,15 @@ class Resource(BaseModel):
self.api_id = api_id
self.path_part = path_part
self.parent_id = parent_id
self.resource_methods = {"GET": {}}
self.resource_methods = {}
def to_dict(self):
response = {
"path": self.get_path(),
"id": self.id,
"resourceMethods": self.resource_methods,
}
if self.resource_methods:
response["resourceMethods"] = self.resource_methods
if self.parent_id:
response["parentId"] = self.parent_id
response["pathPart"] = self.path_part
@ -158,8 +182,12 @@ class Resource(BaseModel):
)
return response.status_code, response.text
def add_method(self, method_type, authorization_type):
method = Method(method_type=method_type, authorization_type=authorization_type)
def add_method(self, method_type, authorization_type, api_key_required):
method = Method(
method_type=method_type,
authorization_type=authorization_type,
api_key_required=api_key_required,
)
self.resource_methods[method_type] = method
return method
@ -182,6 +210,54 @@ class Resource(BaseModel):
return self.resource_methods[method_type].pop("methodIntegration")
class Authorizer(BaseModel, dict):
def __init__(self, id, name, authorizer_type, **kwargs):
super(Authorizer, self).__init__()
self["id"] = id
self["name"] = name
self["type"] = authorizer_type
if kwargs.get("provider_arns"):
self["providerARNs"] = kwargs.get("provider_arns")
if kwargs.get("auth_type"):
self["authType"] = kwargs.get("auth_type")
if kwargs.get("authorizer_uri"):
self["authorizerUri"] = kwargs.get("authorizer_uri")
if kwargs.get("authorizer_credentials"):
self["authorizerCredentials"] = kwargs.get("authorizer_credentials")
if kwargs.get("identity_source"):
self["identitySource"] = kwargs.get("identity_source")
if kwargs.get("identity_validation_expression"):
self["identityValidationExpression"] = kwargs.get(
"identity_validation_expression"
)
self["authorizerResultTtlInSeconds"] = kwargs.get("authorizer_result_ttl")
def apply_operations(self, patch_operations):
for op in patch_operations:
if "/authorizerUri" in op["path"]:
self["authorizerUri"] = op["value"]
elif "/authorizerCredentials" in op["path"]:
self["authorizerCredentials"] = op["value"]
elif "/authorizerResultTtlInSeconds" in op["path"]:
self["authorizerResultTtlInSeconds"] = int(op["value"])
elif "/authType" in op["path"]:
self["authType"] = op["value"]
elif "/identitySource" in op["path"]:
self["identitySource"] = op["value"]
elif "/identityValidationExpression" in op["path"]:
self["identityValidationExpression"] = op["value"]
elif "/name" in op["path"]:
self["name"] = op["value"]
elif "/providerARNs" in op["path"]:
# TODO: add and remove
raise Exception('Patch operation for "%s" not implemented' % op["path"])
elif "/type" in op["path"]:
self["type"] = op["value"]
else:
raise Exception('Patch operation "%s" not implemented' % op["op"])
return self
class Stage(BaseModel, dict):
def __init__(
self,
@ -323,10 +399,10 @@ class ApiKey(BaseModel, dict):
self,
name=None,
description=None,
enabled=True,
enabled=False,
generateDistinctId=False,
value=None,
stageKeys=None,
stageKeys=[],
tags=None,
customerId=None,
):
@ -401,15 +477,17 @@ class RestAPI(BaseModel):
self.description = description
self.create_date = int(time.time())
self.api_key_source = kwargs.get("api_key_source") or "HEADER"
self.policy = kwargs.get("policy") or None
self.endpoint_configuration = kwargs.get("endpoint_configuration") or {
"types": ["EDGE"]
}
self.tags = kwargs.get("tags") or {}
self.deployments = {}
self.authorizers = {}
self.stages = {}
self.resources = {}
self.models = {}
self.add_child("/") # Add default child
def __repr__(self):
@ -424,6 +502,7 @@ class RestAPI(BaseModel):
"apiKeySource": self.api_key_source,
"endpointConfiguration": self.endpoint_configuration,
"tags": self.tags,
"policy": self.policy,
}
def add_child(self, path, parent_id=None):
@ -438,6 +517,29 @@ class RestAPI(BaseModel):
self.resources[child_id] = child
return child
def add_model(
self,
name,
description=None,
schema=None,
content_type=None,
cli_input_json=None,
generate_cli_skeleton=None,
):
model_id = create_id()
new_model = Model(
id=model_id,
name=name,
description=description,
schema=schema,
content_type=content_type,
cli_input_json=cli_input_json,
generate_cli_skeleton=generate_cli_skeleton,
)
self.models[name] = new_model
return new_model
def get_resource_for_path(self, path_after_stage_name):
for resource in self.resources.values():
if resource.get_path() == path_after_stage_name:
@ -474,6 +576,34 @@ class RestAPI(BaseModel):
),
)
def create_authorizer(
self,
id,
name,
authorizer_type,
provider_arns=None,
auth_type=None,
authorizer_uri=None,
authorizer_credentials=None,
identity_source=None,
identiy_validation_expression=None,
authorizer_result_ttl=None,
):
authorizer = Authorizer(
id=id,
name=name,
authorizer_type=authorizer_type,
provider_arns=provider_arns,
auth_type=auth_type,
authorizer_uri=authorizer_uri,
authorizer_credentials=authorizer_credentials,
identity_source=identity_source,
identiy_validation_expression=identiy_validation_expression,
authorizer_result_ttl=authorizer_result_ttl,
)
self.authorizers[id] = authorizer
return authorizer
def create_stage(
self,
name,
@ -513,6 +643,9 @@ class RestAPI(BaseModel):
def get_deployment(self, deployment_id):
return self.deployments[deployment_id]
def get_authorizers(self):
return list(self.authorizers.values())
def get_stages(self):
return list(self.stages.values())
@ -523,6 +656,58 @@ class RestAPI(BaseModel):
return self.deployments.pop(deployment_id)
class DomainName(BaseModel, dict):
def __init__(self, domain_name, **kwargs):
super(DomainName, self).__init__()
self["domainName"] = domain_name
self["regionalDomainName"] = domain_name
self["distributionDomainName"] = domain_name
self["domainNameStatus"] = "AVAILABLE"
self["domainNameStatusMessage"] = "Domain Name Available"
self["regionalHostedZoneId"] = "Z2FDTNDATAQYW2"
self["distributionHostedZoneId"] = "Z2FDTNDATAQYW2"
self["certificateUploadDate"] = int(time.time())
if kwargs.get("certificate_name"):
self["certificateName"] = kwargs.get("certificate_name")
if kwargs.get("certificate_arn"):
self["certificateArn"] = kwargs.get("certificate_arn")
if kwargs.get("certificate_body"):
self["certificateBody"] = kwargs.get("certificate_body")
if kwargs.get("tags"):
self["tags"] = kwargs.get("tags")
if kwargs.get("security_policy"):
self["securityPolicy"] = kwargs.get("security_policy")
if kwargs.get("certificate_chain"):
self["certificateChain"] = kwargs.get("certificate_chain")
if kwargs.get("regional_certificate_name"):
self["regionalCertificateName"] = kwargs.get("regional_certificate_name")
if kwargs.get("certificate_private_key"):
self["certificatePrivateKey"] = kwargs.get("certificate_private_key")
if kwargs.get("regional_certificate_arn"):
self["regionalCertificateArn"] = kwargs.get("regional_certificate_arn")
if kwargs.get("endpoint_configuration"):
self["endpointConfiguration"] = kwargs.get("endpoint_configuration")
if kwargs.get("generate_cli_skeleton"):
self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton")
class Model(BaseModel, dict):
def __init__(self, id, name, **kwargs):
super(Model, self).__init__()
self["id"] = id
self["name"] = name
if kwargs.get("description"):
self["description"] = kwargs.get("description")
if kwargs.get("schema"):
self["schema"] = kwargs.get("schema")
if kwargs.get("content_type"):
self["contentType"] = kwargs.get("content_type")
if kwargs.get("cli_input_json"):
self["cliInputJson"] = kwargs.get("cli_input_json")
if kwargs.get("generate_cli_skeleton"):
self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton")
class APIGatewayBackend(BaseBackend):
def __init__(self, region_name):
super(APIGatewayBackend, self).__init__()
@ -530,6 +715,8 @@ class APIGatewayBackend(BaseBackend):
self.keys = {}
self.usage_plans = {}
self.usage_plan_keys = {}
self.domain_names = {}
self.models = {}
self.region_name = region_name
def reset(self):
@ -544,6 +731,7 @@ class APIGatewayBackend(BaseBackend):
api_key_source=None,
endpoint_configuration=None,
tags=None,
policy=None,
):
api_id = create_id()
rest_api = RestAPI(
@ -554,12 +742,15 @@ class APIGatewayBackend(BaseBackend):
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
policy=policy,
)
self.apis[api_id] = rest_api
return rest_api
def get_rest_api(self, function_id):
rest_api = self.apis[function_id]
rest_api = self.apis.get(function_id)
if rest_api is None:
raise RestAPINotFound()
return rest_api
def list_apis(self):
@ -594,11 +785,60 @@ class APIGatewayBackend(BaseBackend):
resource = self.get_resource(function_id, resource_id)
return resource.get_method(method_type)
def create_method(self, function_id, resource_id, method_type, authorization_type):
def create_method(
self,
function_id,
resource_id,
method_type,
authorization_type,
api_key_required=None,
):
resource = self.get_resource(function_id, resource_id)
method = resource.add_method(method_type, authorization_type)
method = resource.add_method(
method_type, authorization_type, api_key_required=api_key_required
)
return method
def get_authorizer(self, restapi_id, authorizer_id):
api = self.get_rest_api(restapi_id)
authorizer = api.authorizers.get(authorizer_id)
if authorizer is None:
raise AuthorizerNotFoundException()
else:
return authorizer
def get_authorizers(self, restapi_id):
api = self.get_rest_api(restapi_id)
return api.get_authorizers()
def create_authorizer(self, restapi_id, name, authorizer_type, **kwargs):
api = self.get_rest_api(restapi_id)
authorizer_id = create_id()
authorizer = api.create_authorizer(
authorizer_id,
name,
authorizer_type,
provider_arns=kwargs.get("provider_arns"),
auth_type=kwargs.get("auth_type"),
authorizer_uri=kwargs.get("authorizer_uri"),
authorizer_credentials=kwargs.get("authorizer_credentials"),
identity_source=kwargs.get("identity_source"),
identiy_validation_expression=kwargs.get("identiy_validation_expression"),
authorizer_result_ttl=kwargs.get("authorizer_result_ttl"),
)
return api.authorizers.get(authorizer["id"])
def update_authorizer(self, restapi_id, authorizer_id, patch_operations):
authorizer = self.get_authorizer(restapi_id, authorizer_id)
if not authorizer:
api = self.get_rest_api(restapi_id)
authorizer = api.authorizers[authorizer_id] = Authorizer()
return authorizer.apply_operations(patch_operations)
def delete_authorizer(self, restapi_id, authorizer_id):
api = self.get_rest_api(restapi_id)
del api.authorizers[authorizer_id]
def get_stage(self, function_id, stage_name):
api = self.get_rest_api(function_id)
stage = api.stages.get(stage_name)
@ -726,12 +966,13 @@ class APIGatewayBackend(BaseBackend):
status_code,
selection_pattern,
response_templates,
content_handling,
):
if response_templates is None:
raise InvalidRequestInput()
integration = self.get_integration(function_id, resource_id, method_type)
integration_response = integration.create_integration_response(
status_code, selection_pattern
status_code, selection_pattern, response_templates, content_handling
)
return integration_response
@ -821,6 +1062,9 @@ class APIGatewayBackend(BaseBackend):
return plans
def get_usage_plan(self, usage_plan_id):
if usage_plan_id not in self.usage_plans:
raise UsagePlanNotFoundException()
return self.usage_plans[usage_plan_id]
def delete_usage_plan(self, usage_plan_id):
@ -853,6 +1097,17 @@ class APIGatewayBackend(BaseBackend):
return list(self.usage_plan_keys[usage_plan_id].values())
def get_usage_plan_key(self, usage_plan_id, key_id):
# first check if is a valid api key
if key_id not in self.keys:
raise ApiKeyNotFoundException()
# then check if is a valid api key and that the key is in the plan
if (
usage_plan_id not in self.usage_plan_keys
or key_id not in self.usage_plan_keys[usage_plan_id]
):
raise UsagePlanNotFoundException()
return self.usage_plan_keys[usage_plan_id][key_id]
def delete_usage_plan_key(self, usage_plan_id, key_id):
@ -866,6 +1121,98 @@ class APIGatewayBackend(BaseBackend):
except Exception:
return False
def create_domain_name(
self,
domain_name,
certificate_name=None,
tags=None,
certificate_arn=None,
certificate_body=None,
certificate_private_key=None,
certificate_chain=None,
regional_certificate_name=None,
regional_certificate_arn=None,
endpoint_configuration=None,
security_policy=None,
generate_cli_skeleton=None,
):
if not domain_name:
raise InvalidDomainName()
new_domain_name = DomainName(
domain_name=domain_name,
certificate_name=certificate_name,
certificate_private_key=certificate_private_key,
certificate_arn=certificate_arn,
certificate_body=certificate_body,
certificate_chain=certificate_chain,
regional_certificate_name=regional_certificate_name,
regional_certificate_arn=regional_certificate_arn,
endpoint_configuration=endpoint_configuration,
tags=tags,
security_policy=security_policy,
generate_cli_skeleton=generate_cli_skeleton,
)
self.domain_names[domain_name] = new_domain_name
return new_domain_name
def get_domain_names(self):
return list(self.domain_names.values())
def get_domain_name(self, domain_name):
domain_info = self.domain_names.get(domain_name)
if domain_info is None:
raise DomainNameNotFound
else:
return self.domain_names[domain_name]
def create_model(
self,
rest_api_id,
name,
content_type,
description=None,
schema=None,
cli_input_json=None,
generate_cli_skeleton=None,
):
if not rest_api_id:
raise InvalidRestApiId
if not name:
raise InvalidModelName
api = self.get_rest_api(rest_api_id)
new_model = api.add_model(
name=name,
description=description,
schema=schema,
content_type=content_type,
cli_input_json=cli_input_json,
generate_cli_skeleton=generate_cli_skeleton,
)
return new_model
def get_models(self, rest_api_id):
if not rest_api_id:
raise InvalidRestApiId
api = self.get_rest_api(rest_api_id)
models = api.models.values()
return list(models)
def get_model(self, rest_api_id, model_name):
if not rest_api_id:
raise InvalidRestApiId
api = self.get_rest_api(rest_api_id)
model = api.models.get(model_name)
if model is None:
raise ModelNotFound
else:
return model
apigateway_backends = {}
for region_name in Session().get_available_regions("apigateway"):

View File

@ -6,13 +6,22 @@ from moto.core.responses import BaseResponse
from .models import apigateway_backends
from .exceptions import (
ApiKeyNotFoundException,
UsagePlanNotFoundException,
BadRequestException,
CrossAccountNotAllowed,
AuthorizerNotFoundException,
StageNotFoundException,
ApiKeyAlreadyExists,
DomainNameNotFound,
InvalidDomainName,
InvalidRestApiId,
InvalidModelName,
RestAPINotFound,
ModelNotFound,
)
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
AUTHORIZER_TYPES = ["TOKEN", "REQUEST", "COGNITO_USER_POOLS"]
ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
@ -51,6 +60,7 @@ class APIGatewayResponse(BaseResponse):
api_key_source = self._get_param("apiKeySource")
endpoint_configuration = self._get_param("endpointConfiguration")
tags = self._get_param("tags")
policy = self._get_param("policy")
# Param validation
if api_key_source and api_key_source not in API_KEY_SOURCES:
@ -86,6 +96,7 @@ class APIGatewayResponse(BaseResponse):
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
policy=policy,
)
return 200, {}, json.dumps(rest_api.to_dict())
@ -145,8 +156,13 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(method)
elif self.method == "PUT":
authorization_type = self._get_param("authorizationType")
api_key_required = self._get_param("apiKeyRequired")
method = self.backend.create_method(
function_id, resource_id, method_type, authorization_type
function_id,
resource_id,
method_type,
authorization_type,
api_key_required,
)
return 200, {}, json.dumps(method)
@ -172,6 +188,88 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(method_response)
def restapis_authorizers(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
restapi_id = url_path_parts[2]
if self.method == "POST":
name = self._get_param("name")
authorizer_type = self._get_param("type")
provider_arns = self._get_param_with_default_value("providerARNs", None)
auth_type = self._get_param_with_default_value("authType", None)
authorizer_uri = self._get_param_with_default_value("authorizerUri", None)
authorizer_credentials = self._get_param_with_default_value(
"authorizerCredentials", None
)
identity_source = self._get_param_with_default_value("identitySource", None)
identiy_validation_expression = self._get_param_with_default_value(
"identityValidationExpression", None
)
authorizer_result_ttl = self._get_param_with_default_value(
"authorizerResultTtlInSeconds", 300
)
# Param validation
if authorizer_type and authorizer_type not in AUTHORIZER_TYPES:
return self.error(
"ValidationException",
(
"1 validation error detected: "
"Value '{authorizer_type}' at 'createAuthorizerInput.type' failed "
"to satisfy constraint: Member must satisfy enum value set: "
"[TOKEN, REQUEST, COGNITO_USER_POOLS]"
).format(authorizer_type=authorizer_type),
)
authorizer_response = self.backend.create_authorizer(
restapi_id,
name,
authorizer_type,
provider_arns=provider_arns,
auth_type=auth_type,
authorizer_uri=authorizer_uri,
authorizer_credentials=authorizer_credentials,
identity_source=identity_source,
identiy_validation_expression=identiy_validation_expression,
authorizer_result_ttl=authorizer_result_ttl,
)
elif self.method == "GET":
authorizers = self.backend.get_authorizers(restapi_id)
return 200, {}, json.dumps({"item": authorizers})
return 200, {}, json.dumps(authorizer_response)
def authorizers(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
restapi_id = url_path_parts[2]
authorizer_id = url_path_parts[4]
if self.method == "GET":
try:
authorizer_response = self.backend.get_authorizer(
restapi_id, authorizer_id
)
except AuthorizerNotFoundException as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
authorizer_response = self.backend.update_authorizer(
restapi_id, authorizer_id, patch_operations
)
elif self.method == "DELETE":
self.backend.delete_authorizer(restapi_id, authorizer_id)
return 202, {}, "{}"
return 200, {}, json.dumps(authorizer_response)
def restapis_stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -289,6 +387,7 @@ class APIGatewayResponse(BaseResponse):
elif self.method == "PUT":
selection_pattern = self._get_param("selectionPattern")
response_templates = self._get_param("responseTemplates")
content_handling = self._get_param("contentHandling")
integration_response = self.backend.create_integration_response(
function_id,
resource_id,
@ -296,6 +395,7 @@ class APIGatewayResponse(BaseResponse):
status_code,
selection_pattern,
response_templates,
content_handling,
)
elif self.method == "DELETE":
integration_response = self.backend.delete_integration_response(
@ -349,16 +449,15 @@ class APIGatewayResponse(BaseResponse):
except ApiKeyAlreadyExists as error:
return (
error.code,
self.headers,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
return 201, {}, json.dumps(apikey_response)
elif self.method == "GET":
apikeys_response = self.backend.get_apikeys()
return 200, {}, json.dumps({"item": apikeys_response})
return 200, {}, json.dumps(apikey_response)
def apikey_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -366,6 +465,7 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/")
apikey = url_path_parts[2]
status_code = 200
if self.method == "GET":
apikey_response = self.backend.get_apikey(apikey)
elif self.method == "PATCH":
@ -373,7 +473,9 @@ class APIGatewayResponse(BaseResponse):
apikey_response = self.backend.update_apikey(apikey, patch_operations)
elif self.method == "DELETE":
apikey_response = self.backend.delete_apikey(apikey)
return 200, {}, json.dumps(apikey_response)
status_code = 202
return status_code, {}, json.dumps(apikey_response)
def usage_plans(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -393,7 +495,16 @@ class APIGatewayResponse(BaseResponse):
usage_plan = url_path_parts[2]
if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan(usage_plan)
try:
usage_plan_response = self.backend.get_usage_plan(usage_plan)
except (UsagePlanNotFoundException) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan(usage_plan)
return 200, {}, json.dumps(usage_plan_response)
@ -417,13 +528,11 @@ class APIGatewayResponse(BaseResponse):
error.message, error.error_type
),
)
return 201, {}, json.dumps(usage_plan_response)
elif self.method == "GET":
usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id)
return 200, {}, json.dumps({"item": usage_plans_response})
return 200, {}, json.dumps(usage_plan_response)
def usage_plan_key_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -432,9 +541,147 @@ class APIGatewayResponse(BaseResponse):
key_id = url_path_parts[4]
if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id)
try:
usage_plan_response = self.backend.get_usage_plan_key(
usage_plan_id, key_id
)
except (UsagePlanNotFoundException, ApiKeyNotFoundException) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan_key(
usage_plan_id, key_id
)
return 200, {}, json.dumps(usage_plan_response)
def domain_names(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
try:
if self.method == "GET":
domain_names = self.backend.get_domain_names()
return 200, {}, json.dumps({"item": domain_names})
elif self.method == "POST":
domain_name = self._get_param("domainName")
certificate_name = self._get_param("certificateName")
tags = self._get_param("tags")
certificate_arn = self._get_param("certificateArn")
certificate_body = self._get_param("certificateBody")
certificate_private_key = self._get_param("certificatePrivateKey")
certificate_chain = self._get_param("certificateChain")
regional_certificate_name = self._get_param("regionalCertificateName")
regional_certificate_arn = self._get_param("regionalCertificateArn")
endpoint_configuration = self._get_param("endpointConfiguration")
security_policy = self._get_param("securityPolicy")
generate_cli_skeleton = self._get_param("generateCliSkeleton")
domain_name_resp = self.backend.create_domain_name(
domain_name,
certificate_name,
tags,
certificate_arn,
certificate_body,
certificate_private_key,
certificate_chain,
regional_certificate_name,
regional_certificate_arn,
endpoint_configuration,
security_policy,
generate_cli_skeleton,
)
return 200, {}, json.dumps(domain_name_resp)
except InvalidDomainName as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
def domain_name_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
domain_name = url_path_parts[2]
domain_names = {}
try:
if self.method == "GET":
if domain_name is not None:
domain_names = self.backend.get_domain_name(domain_name)
return 200, {}, json.dumps(domain_names)
except DomainNameNotFound as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
def models(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0]
try:
if self.method == "GET":
models = self.backend.get_models(rest_api_id)
return 200, {}, json.dumps({"item": models})
elif self.method == "POST":
name = self._get_param("name")
description = self._get_param("description")
schema = self._get_param("schema")
content_type = self._get_param("contentType")
cli_input_json = self._get_param("cliInputJson")
generate_cli_skeleton = self._get_param("generateCliSkeleton")
model = self.backend.create_model(
rest_api_id,
name,
content_type,
description,
schema,
cli_input_json,
generate_cli_skeleton,
)
return 200, {}, json.dumps(model)
except (InvalidRestApiId, InvalidModelName, RestAPINotFound) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
def model_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
rest_api_id = url_path_parts[2]
model_name = url_path_parts[4]
model_info = {}
try:
if self.method == "GET":
model_info = self.backend.get_model(rest_api_id, model_name)
return 200, {}, json.dumps(model_info)
except (
ModelNotFound,
RestAPINotFound,
InvalidRestApiId,
InvalidModelName,
) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)

View File

@ -7,18 +7,24 @@ url_paths = {
"{0}/restapis$": APIGatewayResponse().restapis,
"{0}/restapis/(?P<function_id>[^/]+)/?$": APIGatewayResponse().restapis_individual,
"{0}/restapis/(?P<function_id>[^/]+)/resources$": APIGatewayResponse().resources,
"{0}/restapis/(?P<function_id>[^/]+)/authorizers$": APIGatewayResponse().restapis_authorizers,
"{0}/restapis/(?P<function_id>[^/]+)/authorizers/(?P<authorizer_id>[^/]+)/?$": APIGatewayResponse().authorizers,
"{0}/restapis/(?P<function_id>[^/]+)/stages$": APIGatewayResponse().restapis_stages,
"{0}/restapis/(?P<function_id>[^/]+)/stages/(?P<stage_name>[^/]+)/?$": APIGatewayResponse().stages,
"{0}/restapis/(?P<function_id>[^/]+)/deployments$": APIGatewayResponse().deployments,
"{0}/restapis/(?P<function_id>[^/]+)/deployments/(?P<deployment_id>[^/]+)/?$": APIGatewayResponse().individual_deployment,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/?$": APIGatewayResponse().resource_individual,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/?$": APIGatewayResponse().resource_methods,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$": APIGatewayResponse().resource_method_responses,
r"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$": APIGatewayResponse().resource_method_responses,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/?$": APIGatewayResponse().integrations,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$": APIGatewayResponse().integration_responses,
r"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$": APIGatewayResponse().integration_responses,
"{0}/apikeys$": APIGatewayResponse().apikeys,
"{0}/apikeys/(?P<apikey>[^/]+)": APIGatewayResponse().apikey_individual,
"{0}/usageplans$": APIGatewayResponse().usage_plans,
"{0}/domainnames$": APIGatewayResponse().domain_names,
"{0}/restapis/(?P<function_id>[^/]+)/models$": APIGatewayResponse().models,
"{0}/restapis/(?P<function_id>[^/]+)/models/(?P<model_name>[^/]+)/?$": APIGatewayResponse().model_induvidual,
"{0}/domainnames/(?P<domain_name>[^/]+)/?$": APIGatewayResponse().domain_name_induvidual,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$": APIGatewayResponse().usage_plan_individual,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$": APIGatewayResponse().usage_plan_keys,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual,

View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import applicationautoscaling_backends
from ..core.models import base_decorator
applicationautoscaling_backend = applicationautoscaling_backends["us-east-1"]
mock_applicationautoscaling = base_decorator(applicationautoscaling_backends)

View File

@ -0,0 +1,9 @@
from __future__ import unicode_literals
from moto.core.exceptions import JsonRESTError
class AWSValidationException(JsonRESTError):
def __init__(self, message, **kwargs):
super(AWSValidationException, self).__init__(
"ValidationException", message, **kwargs
)

View File

@ -0,0 +1,348 @@
from __future__ import unicode_literals
from moto.core import BaseBackend, BaseModel
from moto.ecs import ecs_backends
from .exceptions import AWSValidationException
from collections import OrderedDict
from enum import Enum, unique
import time
import uuid
@unique
class ServiceNamespaceValueSet(Enum):
APPSTREAM = "appstream"
RDS = "rds"
LAMBDA = "lambda"
CASSANDRA = "cassandra"
DYNAMODB = "dynamodb"
CUSTOM_RESOURCE = "custom-resource"
ELASTICMAPREDUCE = "elasticmapreduce"
EC2 = "ec2"
COMPREHEND = "comprehend"
ECS = "ecs"
SAGEMAKER = "sagemaker"
@unique
class ScalableDimensionValueSet(Enum):
CASSANDRA_TABLE_READ_CAPACITY_UNITS = "cassandra:table:ReadCapacityUnits"
CASSANDRA_TABLE_WRITE_CAPACITY_UNITS = "cassandra:table:WriteCapacityUnits"
DYNAMODB_INDEX_READ_CAPACITY_UNITS = "dynamodb:index:ReadCapacityUnits"
DYNAMODB_INDEX_WRITE_CAPACITY_UNITS = "dynamodb:index:WriteCapacityUnits"
DYNAMODB_TABLE_READ_CAPACITY_UNITS = "dynamodb:table:ReadCapacityUnits"
DYNAMODB_TABLE_WRITE_CAPACITY_UNITS = "dynamodb:table:WriteCapacityUnits"
RDS_CLUSTER_READ_REPLICA_COUNT = "rds:cluster:ReadReplicaCount"
RDS_CLUSTER_CAPACITY = "rds:cluster:Capacity"
COMPREHEND_DOCUMENT_CLASSIFIER_ENDPOINT_DESIRED_INFERENCE_UNITS = (
"comprehend:document-classifier-endpoint:DesiredInferenceUnits"
)
ELASTICMAPREDUCE_INSTANCE_FLEET_ON_DEMAND_CAPACITY = (
"elasticmapreduce:instancefleet:OnDemandCapacity"
)
ELASTICMAPREDUCE_INSTANCE_FLEET_SPOT_CAPACITY = (
"elasticmapreduce:instancefleet:SpotCapacity"
)
ELASTICMAPREDUCE_INSTANCE_GROUP_INSTANCE_COUNT = (
"elasticmapreduce:instancegroup:InstanceCount"
)
LAMBDA_FUNCTION_PROVISIONED_CONCURRENCY = "lambda:function:ProvisionedConcurrency"
APPSTREAM_FLEET_DESIRED_CAPACITY = "appstream:fleet:DesiredCapacity"
CUSTOM_RESOURCE_RESOURCE_TYPE_PROPERTY = "custom-resource:ResourceType:Property"
SAGEMAKER_VARIANT_DESIRED_INSTANCE_COUNT = "sagemaker:variant:DesiredInstanceCount"
EC2_SPOT_FLEET_REQUEST_TARGET_CAPACITY = "ec2:spot-fleet-request:TargetCapacity"
ECS_SERVICE_DESIRED_COUNT = "ecs:service:DesiredCount"
class ApplicationAutoscalingBackend(BaseBackend):
def __init__(self, region, ecs):
super(ApplicationAutoscalingBackend, self).__init__()
self.region = region
self.ecs_backend = ecs
self.targets = OrderedDict()
self.policies = {}
def reset(self):
region = self.region
ecs = self.ecs_backend
self.__dict__ = {}
self.__init__(region, ecs)
@property
def applicationautoscaling_backend(self):
return applicationautoscaling_backends[self.region]
def describe_scalable_targets(
self, namespace, r_ids=None, dimension=None,
):
""" Describe scalable targets. """
if r_ids is None:
r_ids = []
targets = self._flatten_scalable_targets(namespace)
if dimension is not None:
targets = [t for t in targets if t.scalable_dimension == dimension]
if len(r_ids) > 0:
targets = [t for t in targets if t.resource_id in r_ids]
return targets
def _flatten_scalable_targets(self, namespace):
""" Flatten scalable targets for a given service namespace down to a list. """
targets = []
for dimension in self.targets.keys():
for resource_id in self.targets[dimension].keys():
targets.append(self.targets[dimension][resource_id])
targets = [t for t in targets if t.service_namespace == namespace]
return targets
def register_scalable_target(self, namespace, r_id, dimension, **kwargs):
""" Registers or updates a scalable target. """
_ = _target_params_are_valid(namespace, r_id, dimension)
if namespace == ServiceNamespaceValueSet.ECS.value:
_ = self._ecs_service_exists_for_target(r_id)
if self._scalable_target_exists(r_id, dimension):
target = self.targets[dimension][r_id]
target.update(**kwargs)
else:
target = FakeScalableTarget(self, namespace, r_id, dimension, **kwargs)
self._add_scalable_target(target)
return target
def _scalable_target_exists(self, r_id, dimension):
return r_id in self.targets.get(dimension, [])
def _ecs_service_exists_for_target(self, r_id):
"""Raises a ValidationException if an ECS service does not exist
for the specified resource ID.
"""
resource_type, cluster, service = r_id.split("/")
result = self.ecs_backend.describe_services(cluster, [service])
if len(result) != 1:
raise AWSValidationException("ECS service doesn't exist: {}".format(r_id))
return True
def _add_scalable_target(self, target):
if target.scalable_dimension not in self.targets:
self.targets[target.scalable_dimension] = OrderedDict()
if target.resource_id not in self.targets[target.scalable_dimension]:
self.targets[target.scalable_dimension][target.resource_id] = target
return target
def deregister_scalable_target(self, namespace, r_id, dimension):
""" Registers or updates a scalable target. """
if self._scalable_target_exists(r_id, dimension):
del self.targets[dimension][r_id]
else:
raise AWSValidationException(
"No scalable target found for service namespace: {}, resource ID: {}, scalable dimension: {}".format(
namespace, r_id, dimension
)
)
def put_scaling_policy(
self,
policy_name,
service_namespace,
resource_id,
scalable_dimension,
policy_body,
policy_type=None,
):
policy_key = FakeApplicationAutoscalingPolicy.formulate_key(
service_namespace, resource_id, scalable_dimension, policy_name
)
if policy_key in self.policies:
old_policy = self.policies[policy_name]
policy = FakeApplicationAutoscalingPolicy(
region_name=self.region,
policy_name=policy_name,
service_namespace=service_namespace,
resource_id=resource_id,
scalable_dimension=scalable_dimension,
policy_type=policy_type if policy_type else old_policy.policy_type,
policy_body=policy_body if policy_body else old_policy._policy_body,
)
else:
policy = FakeApplicationAutoscalingPolicy(
region_name=self.region,
policy_name=policy_name,
service_namespace=service_namespace,
resource_id=resource_id,
scalable_dimension=scalable_dimension,
policy_type=policy_type,
policy_body=policy_body,
)
self.policies[policy_key] = policy
return policy
def describe_scaling_policies(self, service_namespace, **kwargs):
policy_names = kwargs.get("policy_names")
resource_id = kwargs.get("resource_id")
scalable_dimension = kwargs.get("scalable_dimension")
max_results = kwargs.get("max_results") or 100
next_token = kwargs.get("next_token")
policies = [
policy
for policy in self.policies.values()
if policy.service_namespace == service_namespace
]
if policy_names:
policies = [
policy for policy in policies if policy.policy_name in policy_names
]
if resource_id:
policies = [
policy for policy in policies if policy.resource_id in resource_id
]
if scalable_dimension:
policies = [
policy
for policy in policies
if policy.scalable_dimension in scalable_dimension
]
starting_point = int(next_token) if next_token else 0
ending_point = starting_point + max_results
policies_page = policies[starting_point:ending_point]
new_next_token = str(ending_point) if ending_point < len(policies) else None
return new_next_token, policies_page
def delete_scaling_policy(
self, policy_name, service_namespace, resource_id, scalable_dimension
):
policy_key = FakeApplicationAutoscalingPolicy.formulate_key(
service_namespace, resource_id, scalable_dimension, policy_name
)
if policy_key in self.policies:
del self.policies[policy_key]
return {}
else:
raise AWSValidationException(
"No scaling policy found for service namespace: {}, resource ID: {}, scalable dimension: {}, policy name: {}".format(
service_namespace, resource_id, scalable_dimension, policy_name
)
)
def _target_params_are_valid(namespace, r_id, dimension):
""" Check whether namespace, resource_id and dimension are valid and consistent with each other. """
is_valid = True
valid_namespaces = [n.value for n in ServiceNamespaceValueSet]
if namespace not in valid_namespaces:
is_valid = False
if dimension is not None:
try:
valid_dimensions = [d.value for d in ScalableDimensionValueSet]
d_namespace, d_resource_type, scaling_property = dimension.split(":")
resource_type = _get_resource_type_from_resource_id(r_id)
if (
dimension not in valid_dimensions
or d_namespace != namespace
or resource_type != d_resource_type
):
is_valid = False
except ValueError:
is_valid = False
if not is_valid:
raise AWSValidationException(
"Unsupported service namespace, resource type or scalable dimension"
)
return is_valid
def _get_resource_type_from_resource_id(resource_id):
# AWS Application Autoscaling resource_ids are multi-component (path-like) identifiers that vary in format,
# depending on the type of resource it identifies. resource_type is one of its components.
# resource_id format variations are described in
# https://docs.aws.amazon.com/autoscaling/application/APIReference/API_RegisterScalableTarget.html
# In a nutshell:
# - Most use slash separators, but some use colon separators.
# - The resource type is usually the first component of the resource_id...
# - ...except for sagemaker endpoints, dynamodb GSIs and keyspaces tables, where it's the third.
# - Comprehend uses an arn, with the resource type being the last element.
if resource_id.startswith("arn:aws:comprehend"):
resource_id = resource_id.split(":")[-1]
resource_split = (
resource_id.split("/") if "/" in resource_id else resource_id.split(":")
)
if (
resource_split[0] == "endpoint"
or (resource_split[0] == "table" and len(resource_split) > 2)
or (resource_split[0] == "keyspace")
):
resource_type = resource_split[2]
else:
resource_type = resource_split[0]
return resource_type
class FakeScalableTarget(BaseModel):
def __init__(
self, backend, service_namespace, resource_id, scalable_dimension, **kwargs
):
self.applicationautoscaling_backend = backend
self.service_namespace = service_namespace
self.resource_id = resource_id
self.scalable_dimension = scalable_dimension
self.min_capacity = kwargs["min_capacity"]
self.max_capacity = kwargs["max_capacity"]
self.role_arn = kwargs["role_arn"]
self.suspended_state = kwargs["suspended_state"]
self.creation_time = time.time()
def update(self, **kwargs):
if kwargs["min_capacity"] is not None:
self.min_capacity = kwargs["min_capacity"]
if kwargs["max_capacity"] is not None:
self.max_capacity = kwargs["max_capacity"]
if kwargs["suspended_state"] is not None:
self.suspended_state = kwargs["suspended_state"]
class FakeApplicationAutoscalingPolicy(BaseModel):
def __init__(
self,
region_name,
policy_name,
service_namespace,
resource_id,
scalable_dimension,
policy_type,
policy_body,
):
self.step_scaling_policy_configuration = None
self.target_tracking_scaling_policy_configuration = None
if "policy_type" == "StepScaling":
self.step_scaling_policy_configuration = policy_body
self.target_tracking_scaling_policy_configuration = None
elif policy_type == "TargetTrackingScaling":
self.step_scaling_policy_configuration = None
self.target_tracking_scaling_policy_configuration = policy_body
else:
raise AWSValidationException(
"Unknown policy type {} specified.".format(policy_type)
)
self._policy_body = policy_body
self.service_namespace = service_namespace
self.resource_id = resource_id
self.scalable_dimension = scalable_dimension
self.policy_name = policy_name
self.policy_type = policy_type
self._guid = uuid.uuid4()
self.policy_arn = "arn:aws:autoscaling:{}:scalingPolicy:{}:resource/sagemaker/{}:policyName/{}".format(
region_name, self._guid, self.resource_id, self.policy_name
)
self.creation_time = time.time()
@staticmethod
def formulate_key(service_namespace, resource_id, scalable_dimension, policy_name):
return "{}\t{}\t{}\t{}".format(
service_namespace, resource_id, scalable_dimension, policy_name
)
applicationautoscaling_backends = {}
for region_name, ecs_backend in ecs_backends.items():
applicationautoscaling_backends[region_name] = ApplicationAutoscalingBackend(
region_name, ecs_backend
)

View File

@ -0,0 +1,159 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
import json
from .models import (
applicationautoscaling_backends,
ScalableDimensionValueSet,
ServiceNamespaceValueSet,
)
from .exceptions import AWSValidationException
class ApplicationAutoScalingResponse(BaseResponse):
@property
def applicationautoscaling_backend(self):
return applicationautoscaling_backends[self.region]
def describe_scalable_targets(self):
self._validate_params()
service_namespace = self._get_param("ServiceNamespace")
resource_ids = self._get_param("ResourceIds")
scalable_dimension = self._get_param("ScalableDimension")
max_results = self._get_int_param("MaxResults", 50)
marker = self._get_param("NextToken")
all_scalable_targets = self.applicationautoscaling_backend.describe_scalable_targets(
service_namespace, resource_ids, scalable_dimension
)
start = int(marker) + 1 if marker else 0
next_token = None
scalable_targets_resp = all_scalable_targets[start : start + max_results]
if len(all_scalable_targets) > start + max_results:
next_token = str(len(scalable_targets_resp) - 1)
targets = [_build_target(t) for t in scalable_targets_resp]
return json.dumps({"ScalableTargets": targets, "NextToken": next_token})
def register_scalable_target(self):
""" Registers or updates a scalable target. """
self._validate_params()
self.applicationautoscaling_backend.register_scalable_target(
self._get_param("ServiceNamespace"),
self._get_param("ResourceId"),
self._get_param("ScalableDimension"),
min_capacity=self._get_int_param("MinCapacity"),
max_capacity=self._get_int_param("MaxCapacity"),
role_arn=self._get_param("RoleARN"),
suspended_state=self._get_param("SuspendedState"),
)
return json.dumps({})
def deregister_scalable_target(self):
""" Deregisters a scalable target. """
self._validate_params()
self.applicationautoscaling_backend.deregister_scalable_target(
self._get_param("ServiceNamespace"),
self._get_param("ResourceId"),
self._get_param("ScalableDimension"),
)
return json.dumps({})
def put_scaling_policy(self):
policy = self.applicationautoscaling_backend.put_scaling_policy(
policy_name=self._get_param("PolicyName"),
service_namespace=self._get_param("ServiceNamespace"),
resource_id=self._get_param("ResourceId"),
scalable_dimension=self._get_param("ScalableDimension"),
policy_type=self._get_param("PolicyType"),
policy_body=self._get_param(
"StepScalingPolicyConfiguration",
self._get_param("TargetTrackingScalingPolicyConfiguration"),
),
)
return json.dumps({"PolicyARN": policy.policy_arn, "Alarms": []}) # ToDo
def describe_scaling_policies(self):
(
next_token,
policy_page,
) = self.applicationautoscaling_backend.describe_scaling_policies(
service_namespace=self._get_param("ServiceNamespace"),
resource_id=self._get_param("ResourceId"),
scalable_dimension=self._get_param("ScalableDimension"),
max_results=self._get_param("MaxResults"),
next_token=self._get_param("NextToken"),
)
response_obj = {"ScalingPolicies": [_build_policy(p) for p in policy_page]}
if next_token:
response_obj["NextToken"] = next_token
return json.dumps(response_obj)
def delete_scaling_policy(self):
self.applicationautoscaling_backend.delete_scaling_policy(
policy_name=self._get_param("PolicyName"),
service_namespace=self._get_param("ServiceNamespace"),
resource_id=self._get_param("ResourceId"),
scalable_dimension=self._get_param("ScalableDimension"),
)
return json.dumps({})
def _validate_params(self):
"""Validate parameters.
TODO Integrate this validation with the validation in models.py
"""
namespace = self._get_param("ServiceNamespace")
dimension = self._get_param("ScalableDimension")
messages = []
dimensions = [d.value for d in ScalableDimensionValueSet]
message = None
if dimension is not None and dimension not in dimensions:
messages.append(
"Value '{}' at 'scalableDimension' "
"failed to satisfy constraint: Member must satisfy enum value set: "
"{}".format(dimension, dimensions)
)
namespaces = [n.value for n in ServiceNamespaceValueSet]
if namespace is not None and namespace not in namespaces:
messages.append(
"Value '{}' at 'serviceNamespace' "
"failed to satisfy constraint: Member must satisfy enum value set: "
"{}".format(namespace, namespaces)
)
if len(messages) == 1:
message = "1 validation error detected: {}".format(messages[0])
elif len(messages) > 1:
message = "{} validation errors detected: {}".format(
len(messages), "; ".join(messages)
)
if message:
raise AWSValidationException(message)
def _build_target(t):
return {
"CreationTime": t.creation_time,
"ServiceNamespace": t.service_namespace,
"ResourceId": t.resource_id,
"RoleARN": t.role_arn,
"ScalableDimension": t.scalable_dimension,
"MaxCapacity": t.max_capacity,
"MinCapacity": t.min_capacity,
"SuspendedState": t.suspended_state,
}
def _build_policy(p):
response = {
"PolicyARN": p.policy_arn,
"PolicyName": p.policy_name,
"ServiceNamespace": p.service_namespace,
"ResourceId": p.resource_id,
"ScalableDimension": p.scalable_dimension,
"PolicyType": p.policy_type,
"CreationTime": p.creation_time,
}
if p.policy_type == "StepScaling":
response["StepScalingPolicyConfiguration"] = p.step_scaling_policy_configuration
elif p.policy_type == "TargetTrackingScaling":
response[
"TargetTrackingScalingPolicyConfiguration"
] = p.target_tracking_scaling_policy_configuration
return response

View File

@ -0,0 +1,8 @@
from __future__ import unicode_literals
from .responses import ApplicationAutoScalingResponse
url_bases = ["https?://application-autoscaling.(.+).amazonaws.com"]
url_paths = {
"{0}/$": ApplicationAutoScalingResponse.dispatch,
}

View File

@ -0,0 +1,10 @@
from six.moves.urllib.parse import urlparse
def region_from_applicationautoscaling_url(url):
domain = urlparse(url).netloc
if "." in domain:
return domain.split(".")[1]
else:
return "us-east-1"

View File

@ -2,10 +2,9 @@ from __future__ import unicode_literals
import time
from boto3 import Session
from moto.core import BaseBackend, BaseModel, ACCOUNT_ID
from moto.core import BaseBackend, BaseModel
from moto.core import ACCOUNT_ID
from uuid import uuid4
class TaggableResourceMixin(object):
@ -50,6 +49,27 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
self.configuration = configuration
class Execution(BaseModel):
def __init__(self, query, context, config, workgroup):
self.id = str(uuid4())
self.query = query
self.context = context
self.config = config
self.workgroup = workgroup
self.start_time = time.time()
self.status = "QUEUED"
class NamedQuery(BaseModel):
def __init__(self, name, description, database, query_string, workgroup):
self.id = str(uuid4())
self.name = name
self.description = description
self.database = database
self.query_string = query_string
self.workgroup = workgroup
class AthenaBackend(BaseBackend):
region_name = None
@ -57,6 +77,8 @@ class AthenaBackend(BaseBackend):
if region_name is not None:
self.region_name = region_name
self.work_groups = {}
self.executions = {}
self.named_queries = {}
def create_work_group(self, name, configuration, description, tags):
if name in self.work_groups:
@ -76,6 +98,46 @@ class AthenaBackend(BaseBackend):
for wg in self.work_groups.values()
]
def get_work_group(self, name):
if name not in self.work_groups:
return None
wg = self.work_groups[name]
return {
"Name": wg.name,
"State": wg.state,
"Configuration": wg.configuration,
"Description": wg.description,
"CreationTime": time.time(),
}
def start_query_execution(self, query, context, config, workgroup):
execution = Execution(
query=query, context=context, config=config, workgroup=workgroup
)
self.executions[execution.id] = execution
return execution.id
def get_execution(self, exec_id):
return self.executions[exec_id]
def stop_query_execution(self, exec_id):
execution = self.executions[exec_id]
execution.status = "CANCELLED"
def create_named_query(self, name, description, database, query_string, workgroup):
nq = NamedQuery(
name=name,
description=description,
database=database,
query_string=query_string,
workgroup=workgroup,
)
self.named_queries[nq.id] = nq
return nq.id
def get_named_query(self, query_id):
return self.named_queries[query_id] if query_id in self.named_queries else None
athena_backends = {}
for region in Session().get_available_regions("athena"):

View File

@ -18,15 +18,7 @@ class AthenaResponse(BaseResponse):
name, configuration, description, tags
)
if not work_group:
return (
json.dumps(
{
"__type": "InvalidRequestException",
"Message": "WorkGroup already exists",
}
),
dict(status=400),
)
return self.error("WorkGroup already exists", 400)
return json.dumps(
{
"CreateWorkGroupResponse": {
@ -39,3 +31,86 @@ class AthenaResponse(BaseResponse):
def list_work_groups(self):
return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()})
def get_work_group(self):
name = self._get_param("WorkGroup")
return json.dumps({"WorkGroup": self.athena_backend.get_work_group(name)})
def start_query_execution(self):
query = self._get_param("QueryString")
context = self._get_param("QueryExecutionContext")
config = self._get_param("ResultConfiguration")
workgroup = self._get_param("WorkGroup")
if workgroup and not self.athena_backend.get_work_group(workgroup):
return self.error("WorkGroup does not exist", 400)
id = self.athena_backend.start_query_execution(
query=query, context=context, config=config, workgroup=workgroup
)
return json.dumps({"QueryExecutionId": id})
def get_query_execution(self):
exec_id = self._get_param("QueryExecutionId")
execution = self.athena_backend.get_execution(exec_id)
result = {
"QueryExecution": {
"QueryExecutionId": exec_id,
"Query": execution.query,
"StatementType": "DDL",
"ResultConfiguration": execution.config,
"QueryExecutionContext": execution.context,
"Status": {
"State": execution.status,
"SubmissionDateTime": execution.start_time,
},
"Statistics": {
"EngineExecutionTimeInMillis": 0,
"DataScannedInBytes": 0,
"TotalExecutionTimeInMillis": 0,
"QueryQueueTimeInMillis": 0,
"QueryPlanningTimeInMillis": 0,
"ServiceProcessingTimeInMillis": 0,
},
"WorkGroup": execution.workgroup,
}
}
return json.dumps(result)
def stop_query_execution(self):
exec_id = self._get_param("QueryExecutionId")
self.athena_backend.stop_query_execution(exec_id)
return json.dumps({})
def error(self, msg, status):
return (
json.dumps({"__type": "InvalidRequestException", "Message": msg,}),
dict(status=status),
)
def create_named_query(self):
name = self._get_param("Name")
description = self._get_param("Description")
database = self._get_param("Database")
query_string = self._get_param("QueryString")
workgroup = self._get_param("WorkGroup")
if workgroup and not self.athena_backend.get_work_group(workgroup):
return self.error("WorkGroup does not exist", 400)
query_id = self.athena_backend.create_named_query(
name, description, database, query_string, workgroup
)
return json.dumps({"NamedQueryId": query_id})
def get_named_query(self):
query_id = self._get_param("NamedQueryId")
nq = self.athena_backend.get_named_query(query_id)
return json.dumps(
{
"NamedQuery": {
"Name": nq.name,
"Description": nq.description,
"Database": nq.database,
"QueryString": nq.query_string,
"NamedQueryId": nq.id,
"WorkGroup": nq.workgroup,
}
}
)

View File

@ -21,3 +21,8 @@ class InvalidInstanceError(AutoscalingClientError):
super(InvalidInstanceError, self).__init__(
"ValidationError", "Instance [{0}] is invalid.".format(instance_id)
)
class ValidationError(AutoscalingClientError):
def __init__(self, message):
super(ValidationError, self).__init__("ValidationError", message)

View File

@ -2,11 +2,15 @@ from __future__ import unicode_literals
import random
from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping
from moto.packages.boto.ec2.blockdevicemapping import (
BlockDeviceType,
BlockDeviceMapping,
)
from moto.ec2.exceptions import InvalidInstanceIdError
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.utils import camelcase_to_underscores
from moto.ec2 import ec2_backends
from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
@ -15,6 +19,7 @@ from .exceptions import (
AutoscalingClientError,
ResourceContentionError,
InvalidInstanceError,
ValidationError,
)
# http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
@ -74,7 +79,7 @@ class FakeScalingPolicy(BaseModel):
)
class FakeLaunchConfiguration(BaseModel):
class FakeLaunchConfiguration(CloudFormationModel):
def __init__(
self,
name,
@ -127,6 +132,15 @@ class FakeLaunchConfiguration(BaseModel):
)
return config
@staticmethod
def cloudformation_name_type():
return "LaunchConfigurationName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-autoscaling-launchconfiguration.html
return "AWS::AutoScaling::LaunchConfiguration"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
@ -215,7 +229,7 @@ class FakeLaunchConfiguration(BaseModel):
return block_device_map
class FakeAutoScalingGroup(BaseModel):
class FakeAutoScalingGroup(CloudFormationModel):
def __init__(
self,
name,
@ -224,6 +238,7 @@ class FakeAutoScalingGroup(BaseModel):
max_size,
min_size,
launch_config_name,
launch_template,
vpc_zone_identifier,
default_cooldown,
health_check_period,
@ -233,10 +248,12 @@ class FakeAutoScalingGroup(BaseModel):
placement_group,
termination_policies,
autoscaling_backend,
ec2_backend,
tags,
new_instances_protected_from_scale_in=False,
):
self.autoscaling_backend = autoscaling_backend
self.ec2_backend = ec2_backend
self.name = name
self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier)
@ -244,10 +261,10 @@ class FakeAutoScalingGroup(BaseModel):
self.max_size = max_size
self.min_size = min_size
self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name
]
self.launch_config_name = launch_config_name
self.launch_template = None
self.launch_config = None
self._set_launch_configuration(launch_config_name, launch_template)
self.default_cooldown = (
default_cooldown if default_cooldown else DEFAULT_COOLDOWN
@ -267,6 +284,9 @@ class FakeAutoScalingGroup(BaseModel):
self.tags = tags if tags else []
self.set_desired_capacity(desired_capacity)
def active_instances(self):
return [x for x in self.instance_states if x.lifecycle_state == "InService"]
def _set_azs_and_vpcs(self, availability_zones, vpc_zone_identifier, update=False):
# for updates, if only AZs are provided, they must not clash with
# the AZs of existing VPCs
@ -298,6 +318,51 @@ class FakeAutoScalingGroup(BaseModel):
self.availability_zones = availability_zones
self.vpc_zone_identifier = vpc_zone_identifier
def _set_launch_configuration(self, launch_config_name, launch_template):
if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name
]
self.launch_config_name = launch_config_name
if launch_template:
launch_template_id = launch_template.get("launch_template_id")
launch_template_name = launch_template.get("launch_template_name")
if not (launch_template_id or launch_template_name) or (
launch_template_id and launch_template_name
):
raise ValidationError(
"Valid requests must contain either launchTemplateId or LaunchTemplateName"
)
if launch_template_id:
self.launch_template = self.ec2_backend.get_launch_template(
launch_template_id
)
elif launch_template_name:
self.launch_template = self.ec2_backend.get_launch_template_by_name(
launch_template_name
)
self.launch_template_version = int(launch_template["version"])
@staticmethod
def __set_string_propagate_at_launch_booleans_on_tags(tags):
bool_to_string = {True: "true", False: "false"}
for tag in tags:
if "PropagateAtLaunch" in tag:
tag["PropagateAtLaunch"] = bool_to_string[tag["PropagateAtLaunch"]]
return tags
@staticmethod
def cloudformation_name_type():
return "AutoScalingGroupName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-autoscaling-autoscalinggroup.html
return "AWS::AutoScaling::AutoScalingGroup"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
@ -305,6 +370,10 @@ class FakeAutoScalingGroup(BaseModel):
properties = cloudformation_json["Properties"]
launch_config_name = properties.get("LaunchConfigurationName")
launch_template = {
camelcase_to_underscores(k): v
for k, v in properties.get("LaunchTemplate", {}).items()
}
load_balancer_names = properties.get("LoadBalancerNames", [])
target_group_arns = properties.get("TargetGroupARNs", [])
@ -316,6 +385,7 @@ class FakeAutoScalingGroup(BaseModel):
max_size=properties.get("MaxSize"),
min_size=properties.get("MinSize"),
launch_config_name=launch_config_name,
launch_template=launch_template,
vpc_zone_identifier=(
",".join(properties.get("VPCZoneIdentifier", [])) or None
),
@ -326,7 +396,9 @@ class FakeAutoScalingGroup(BaseModel):
target_group_arns=target_group_arns,
placement_group=None,
termination_policies=properties.get("TerminationPolicies", []),
tags=properties.get("Tags", []),
tags=cls.__set_string_propagate_at_launch_booleans_on_tags(
properties.get("Tags", [])
),
new_instances_protected_from_scale_in=properties.get(
"NewInstancesProtectedFromScaleIn", False
),
@ -362,6 +434,38 @@ class FakeAutoScalingGroup(BaseModel):
def physical_resource_id(self):
return self.name
@property
def image_id(self):
if self.launch_template:
version = self.launch_template.get_version(self.launch_template_version)
return version.image_id
return self.launch_config.image_id
@property
def instance_type(self):
if self.launch_template:
version = self.launch_template.get_version(self.launch_template_version)
return version.instance_type
return self.launch_config.instance_type
@property
def user_data(self):
if self.launch_template:
version = self.launch_template.get_version(self.launch_template_version)
return version.user_data
return self.launch_config.user_data
@property
def security_groups(self):
if self.launch_template:
version = self.launch_template.get_version(self.launch_template_version)
return version.security_groups
return self.launch_config.security_groups
def update(
self,
availability_zones,
@ -369,6 +473,7 @@ class FakeAutoScalingGroup(BaseModel):
max_size,
min_size,
launch_config_name,
launch_template,
vpc_zone_identifier,
default_cooldown,
health_check_period,
@ -390,11 +495,8 @@ class FakeAutoScalingGroup(BaseModel):
if max_size is not None and max_size < len(self.instance_states):
desired_capacity = max_size
if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name
]
self.launch_config_name = launch_config_name
self._set_launch_configuration(launch_config_name, launch_template)
if health_check_period is not None:
self.health_check_period = health_check_period
if health_check_type is not None:
@ -413,12 +515,11 @@ class FakeAutoScalingGroup(BaseModel):
else:
self.desired_capacity = new_capacity
curr_instance_count = len(self.instance_states)
curr_instance_count = len(self.active_instances())
if self.desired_capacity == curr_instance_count:
return
if self.desired_capacity > curr_instance_count:
pass # Nothing to do here
elif self.desired_capacity > curr_instance_count:
# Need more instances
count_needed = int(self.desired_capacity) - int(curr_instance_count)
@ -442,6 +543,9 @@ class FakeAutoScalingGroup(BaseModel):
self.instance_states = list(
set(self.instance_states) - set(instances_to_remove)
)
if self.name in self.autoscaling_backend.autoscaling_groups:
self.autoscaling_backend.update_attached_elbs(self.name)
self.autoscaling_backend.update_attached_target_groups(self.name)
def get_propagated_tags(self):
propagated_tags = {}
@ -450,18 +554,19 @@ class FakeAutoScalingGroup(BaseModel):
# boto3 and cloudformation use PropagateAtLaunch
if "propagate_at_launch" in tag and tag["propagate_at_launch"] == "true":
propagated_tags[tag["key"]] = tag["value"]
if "PropagateAtLaunch" in tag and tag["PropagateAtLaunch"]:
if "PropagateAtLaunch" in tag and tag["PropagateAtLaunch"] == "true":
propagated_tags[tag["Key"]] = tag["Value"]
return propagated_tags
def replace_autoscaling_group_instances(self, count_needed, propagated_tags):
propagated_tags[ASG_NAME_TAG] = self.name
reservation = self.autoscaling_backend.ec2_backend.add_instances(
self.launch_config.image_id,
self.image_id,
count_needed,
self.launch_config.user_data,
self.launch_config.security_groups,
instance_type=self.launch_config.instance_type,
self.user_data,
self.security_groups,
instance_type=self.instance_type,
tags={"instance": propagated_tags},
placement=random.choice(self.availability_zones),
)
@ -553,6 +658,7 @@ class AutoScalingBackend(BaseBackend):
max_size,
min_size,
launch_config_name,
launch_template,
vpc_zone_identifier,
default_cooldown,
health_check_period,
@ -576,7 +682,19 @@ class AutoScalingBackend(BaseBackend):
health_check_period = 300
else:
health_check_period = make_int(health_check_period)
if launch_config_name is None and instance_id is not None:
# TODO: Add MixedInstancesPolicy once implemented.
# Verify only a single launch config-like parameter is provided.
params = [launch_config_name, launch_template, instance_id]
num_params = sum([1 for param in params if param])
if num_params != 1:
raise ValidationError(
"Valid requests must contain either LaunchTemplate, LaunchConfigurationName, "
"InstanceId or MixedInstancesPolicy parameter."
)
if instance_id:
try:
instance = self.ec2_backend.get_instance(instance_id)
launch_config_name = name
@ -593,6 +711,7 @@ class AutoScalingBackend(BaseBackend):
max_size=max_size,
min_size=min_size,
launch_config_name=launch_config_name,
launch_template=launch_template,
vpc_zone_identifier=vpc_zone_identifier,
default_cooldown=default_cooldown,
health_check_period=health_check_period,
@ -602,6 +721,7 @@ class AutoScalingBackend(BaseBackend):
placement_group=placement_group,
termination_policies=termination_policies,
autoscaling_backend=self,
ec2_backend=self.ec2_backend,
tags=tags,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in,
)
@ -619,6 +739,7 @@ class AutoScalingBackend(BaseBackend):
max_size,
min_size,
launch_config_name,
launch_template,
vpc_zone_identifier,
default_cooldown,
health_check_period,
@ -627,19 +748,28 @@ class AutoScalingBackend(BaseBackend):
termination_policies,
new_instances_protected_from_scale_in=None,
):
# TODO: Add MixedInstancesPolicy once implemented.
# Verify only a single launch config-like parameter is provided.
if launch_config_name and launch_template:
raise ValidationError(
"Valid requests must contain either LaunchTemplate, LaunchConfigurationName "
"or MixedInstancesPolicy parameter."
)
group = self.autoscaling_groups[name]
group.update(
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
availability_zones=availability_zones,
desired_capacity=desired_capacity,
max_size=max_size,
min_size=min_size,
launch_config_name=launch_config_name,
launch_template=launch_template,
vpc_zone_identifier=vpc_zone_identifier,
default_cooldown=default_cooldown,
health_check_period=health_check_period,
health_check_type=health_check_type,
placement_group=placement_group,
termination_policies=termination_policies,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in,
)
return group
@ -655,10 +785,16 @@ class AutoScalingBackend(BaseBackend):
self.set_desired_capacity(group_name, 0)
self.autoscaling_groups.pop(group_name, None)
def describe_auto_scaling_instances(self):
def describe_auto_scaling_instances(self, instance_ids):
instance_states = []
for group in self.autoscaling_groups.values():
instance_states.extend(group.instance_states)
instance_states.extend(
[
x
for x in group.instance_states
if not instance_ids or x.instance.id in instance_ids
]
)
return instance_states
def attach_instances(self, group_name, instance_ids):
@ -682,6 +818,7 @@ class AutoScalingBackend(BaseBackend):
)
group.instance_states.extend(new_instances)
self.update_attached_elbs(group.name)
self.update_attached_target_groups(group.name)
def set_instance_health(
self, instance_id, health_status, should_respect_grace_period
@ -697,7 +834,7 @@ class AutoScalingBackend(BaseBackend):
def detach_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name]
original_size = len(group.instance_states)
original_size = group.desired_capacity
detached_instances = [
x for x in group.instance_states if x.instance.id in instance_ids
@ -714,13 +851,8 @@ class AutoScalingBackend(BaseBackend):
if should_decrement:
group.desired_capacity = original_size - len(instance_ids)
else:
count_needed = len(instance_ids)
group.replace_autoscaling_group_instances(
count_needed, group.get_propagated_tags()
)
self.update_attached_elbs(group_name)
group.set_desired_capacity(group.desired_capacity)
return detached_instances
def set_desired_capacity(self, group_name, desired_capacity):
@ -734,7 +866,7 @@ class AutoScalingBackend(BaseBackend):
self.set_desired_capacity(group_name, desired_capacity)
def change_capacity_percent(self, group_name, scaling_adjustment):
""" http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/as-scale-based-on-demand.html
"""http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/as-scale-based-on-demand.html
If PercentChangeInCapacity returns a value between 0 and 1,
Auto Scaling will round it off to 1. If the PercentChangeInCapacity
returns a value greater than 1, Auto Scaling will round it off to the
@ -785,7 +917,9 @@ class AutoScalingBackend(BaseBackend):
def update_attached_elbs(self, group_name):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(state.instance.id for state in group.instance_states)
group_instance_ids = set(
state.instance.id for state in group.active_instances()
)
# skip this if group.load_balancers is empty
# otherwise elb_backend.describe_load_balancers returns all available load balancers
@ -902,15 +1036,15 @@ class AutoScalingBackend(BaseBackend):
autoscaling_group_name,
autoscaling_group,
) in self.autoscaling_groups.items():
original_instance_count = len(autoscaling_group.instance_states)
original_active_instance_count = len(autoscaling_group.active_instances())
autoscaling_group.instance_states = list(
filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states,
)
)
difference = original_instance_count - len(
autoscaling_group.instance_states
difference = original_active_instance_count - len(
autoscaling_group.active_instances()
)
if difference > 0:
autoscaling_group.replace_autoscaling_group_instances(
@ -918,6 +1052,45 @@ class AutoScalingBackend(BaseBackend):
)
self.update_attached_elbs(autoscaling_group_name)
def enter_standby_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name]
original_size = group.desired_capacity
standby_instances = []
for instance_state in group.instance_states:
if instance_state.instance.id in instance_ids:
instance_state.lifecycle_state = "Standby"
standby_instances.append(instance_state)
if should_decrement:
group.desired_capacity = group.desired_capacity - len(instance_ids)
group.set_desired_capacity(group.desired_capacity)
return standby_instances, original_size, group.desired_capacity
def exit_standby_instances(self, group_name, instance_ids):
group = self.autoscaling_groups[group_name]
original_size = group.desired_capacity
standby_instances = []
for instance_state in group.instance_states:
if instance_state.instance.id in instance_ids:
instance_state.lifecycle_state = "InService"
standby_instances.append(instance_state)
group.desired_capacity = group.desired_capacity + len(instance_ids)
group.set_desired_capacity(group.desired_capacity)
return standby_instances, original_size, group.desired_capacity
def terminate_instance(self, instance_id, should_decrement):
instance = self.ec2_backend.get_instance(instance_id)
instance_state = next(
instance_state
for group in self.autoscaling_groups.values()
for instance_state in group.instance_states
if instance_state.instance.id == instance.id
)
group = instance.autoscaling_group
original_size = group.desired_capacity
self.detach_instances(group.name, [instance.id], should_decrement)
self.ec2_backend.terminate_instances([instance.id])
return instance_state, original_size, group.desired_capacity
autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items():

View File

@ -1,7 +1,12 @@
from __future__ import unicode_literals
import datetime
from moto.core.responses import BaseResponse
from moto.core.utils import amz_crc32, amzn_request_id
from moto.core.utils import (
amz_crc32,
amzn_request_id,
iso_8601_datetime_with_milliseconds,
)
from .models import autoscaling_backends
@ -76,6 +81,7 @@ class AutoScalingResponse(BaseResponse):
min_size=self._get_int_param("MinSize"),
instance_id=self._get_param("InstanceId"),
launch_config_name=self._get_param("LaunchConfigurationName"),
launch_template=self._get_dict_param("LaunchTemplate."),
vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
default_cooldown=self._get_int_param("DefaultCooldown"),
health_check_period=self._get_int_param("HealthCheckGracePeriod"),
@ -192,6 +198,7 @@ class AutoScalingResponse(BaseResponse):
max_size=self._get_int_param("MaxSize"),
min_size=self._get_int_param("MinSize"),
launch_config_name=self._get_param("LaunchConfigurationName"),
launch_template=self._get_dict_param("LaunchTemplate."),
vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
default_cooldown=self._get_int_param("DefaultCooldown"),
health_check_period=self._get_int_param("HealthCheckGracePeriod"),
@ -226,7 +233,9 @@ class AutoScalingResponse(BaseResponse):
return template.render()
def describe_auto_scaling_instances(self):
instance_states = self.autoscaling_backend.describe_auto_scaling_instances()
instance_states = self.autoscaling_backend.describe_auto_scaling_instances(
instance_ids=self._get_multi_param("InstanceIds.member")
)
template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
return template.render(instance_states=instance_states)
@ -289,6 +298,50 @@ class AutoScalingResponse(BaseResponse):
template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def enter_standby(self):
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
if should_decrement_string == "true":
should_decrement = True
else:
should_decrement = False
(
standby_instances,
original_size,
desired_capacity,
) = self.autoscaling_backend.enter_standby_instances(
group_name, instance_ids, should_decrement
)
template = self.response_template(ENTER_STANDBY_TEMPLATE)
return template.render(
standby_instances=standby_instances,
should_decrement=should_decrement,
original_size=original_size,
desired_capacity=desired_capacity,
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
)
@amz_crc32
@amzn_request_id
def exit_standby(self):
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
(
standby_instances,
original_size,
desired_capacity,
) = self.autoscaling_backend.exit_standby_instances(group_name, instance_ids)
template = self.response_template(EXIT_STANDBY_TEMPLATE)
return template.render(
standby_instances=standby_instances,
original_size=original_size,
desired_capacity=desired_capacity,
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
)
def suspend_processes(self):
autoscaling_group_name = self._get_param("AutoScalingGroupName")
scaling_processes = self._get_multi_param("ScalingProcesses.member")
@ -308,6 +361,29 @@ class AutoScalingResponse(BaseResponse):
template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def terminate_instance_in_auto_scaling_group(self):
instance_id = self._get_param("InstanceId")
should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
if should_decrement_string == "true":
should_decrement = True
else:
should_decrement = False
(
instance,
original_size,
desired_capacity,
) = self.autoscaling_backend.terminate_instance(instance_id, should_decrement)
template = self.response_template(TERMINATE_INSTANCES_TEMPLATE)
return template.render(
instance=instance,
should_decrement=should_decrement,
original_size=original_size,
desired_capacity=desired_capacity,
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
)
CREATE_LAUNCH_CONFIGURATION_TEMPLATE = """<CreateLaunchConfigurationResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<ResponseMetadata>
@ -499,14 +575,31 @@ DESCRIBE_AUTOSCALING_GROUPS_TEMPLATE = """<DescribeAutoScalingGroupsResponse xml
<HealthCheckType>{{ group.health_check_type }}</HealthCheckType>
<CreatedTime>2013-05-06T17:47:15.107Z</CreatedTime>
<EnabledMetrics/>
{% if group.launch_config_name %}
<LaunchConfigurationName>{{ group.launch_config_name }}</LaunchConfigurationName>
{% elif group.launch_template %}
<LaunchTemplate>
<LaunchTemplateId>{{ group.launch_template.id }}</LaunchTemplateId>
<Version>{{ group.launch_template_version }}</Version>
<LaunchTemplateName>{{ group.launch_template.name }}</LaunchTemplateName>
</LaunchTemplate>
{% endif %}
<Instances>
{% for instance_state in group.instance_states %}
<member>
<HealthStatus>{{ instance_state.health_status }}</HealthStatus>
<AvailabilityZone>{{ instance_state.instance.placement }}</AvailabilityZone>
<InstanceId>{{ instance_state.instance.id }}</InstanceId>
<InstanceType>{{ instance_state.instance.instance_type }}</InstanceType>
{% if group.launch_config_name %}
<LaunchConfigurationName>{{ group.launch_config_name }}</LaunchConfigurationName>
{% elif group.launch_template %}
<LaunchTemplate>
<LaunchTemplateId>{{ group.launch_template.id }}</LaunchTemplateId>
<Version>{{ group.launch_template_version }}</Version>
<LaunchTemplateName>{{ group.launch_template.name }}</LaunchTemplateName>
</LaunchTemplate>
{% endif %}
<LifecycleState>{{ instance_state.lifecycle_state }}</LifecycleState>
<ProtectedFromScaleIn>{{ instance_state.protected_from_scale_in|string|lower }}</ProtectedFromScaleIn>
</member>
@ -592,7 +685,16 @@ DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE = """<DescribeAutoScalingInstancesRespon
<AutoScalingGroupName>{{ instance_state.instance.autoscaling_group.name }}</AutoScalingGroupName>
<AvailabilityZone>{{ instance_state.instance.placement }}</AvailabilityZone>
<InstanceId>{{ instance_state.instance.id }}</InstanceId>
<InstanceType>{{ instance_state.instance.instance_type }}</InstanceType>
{% if instance_state.instance.autoscaling_group.launch_config_name %}
<LaunchConfigurationName>{{ instance_state.instance.autoscaling_group.launch_config_name }}</LaunchConfigurationName>
{% elif instance_state.instance.autoscaling_group.launch_template %}
<LaunchTemplate>
<LaunchTemplateId>{{ instance_state.instance.autoscaling_group.launch_template.id }}</LaunchTemplateId>
<Version>{{ instance_state.instance.autoscaling_group.launch_template_version }}</Version>
<LaunchTemplateName>{{ instance_state.instance.autoscaling_group.launch_template.name }}</LaunchTemplateName>
</LaunchTemplate>
{% endif %}
<LifecycleState>{{ instance_state.lifecycle_state }}</LifecycleState>
<ProtectedFromScaleIn>{{ instance_state.protected_from_scale_in|string|lower }}</ProtectedFromScaleIn>
</member>
@ -705,3 +807,73 @@ SET_INSTANCE_PROTECTION_TEMPLATE = """<SetInstanceProtectionResponse xmlns="http
<RequestId></RequestId>
</ResponseMetadata>
</SetInstanceProtectionResponse>"""
ENTER_STANDBY_TEMPLATE = """<EnterStandbyResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<EnterStandbyResult>
<Activities>
{% for instance in standby_instances %}
<member>
<ActivityId>12345678-1234-1234-1234-123456789012</ActivityId>
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
{% if should_decrement %}
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved to standby in response to a user request, shrinking the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
{% else %}
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved to standby in response to a user request.</Cause>
{% endif %}
<Description>Moving EC2 instance to Standby: {{ instance.instance.id }}</Description>
<Progress>50</Progress>
<StartTime>{{ timestamp }}</StartTime>
<Details>{&quot;Subnet ID&quot;:&quot;??&quot;,&quot;Availability Zone&quot;:&quot;{{ instance.instance.placement }}&quot;}</Details>
<StatusCode>InProgress</StatusCode>
</member>
{% endfor %}
</Activities>
</EnterStandbyResult>
<ResponseMetadata>
<RequestId>7c6e177f-f082-11e1-ac58-3714bEXAMPLE</RequestId>
</ResponseMetadata>
</EnterStandbyResponse>"""
EXIT_STANDBY_TEMPLATE = """<ExitStandbyResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<ExitStandbyResult>
<Activities>
{% for instance in standby_instances %}
<member>
<ActivityId>12345678-1234-1234-1234-123456789012</ActivityId>
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
<Description>Moving EC2 instance out of Standby: {{ instance.instance.id }}</Description>
<Progress>30</Progress>
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved out of standby in response to a user request, increasing the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
<StartTime>{{ timestamp }}</StartTime>
<Details>{&quot;Subnet ID&quot;:&quot;??&quot;,&quot;Availability Zone&quot;:&quot;{{ instance.instance.placement }}&quot;}</Details>
<StatusCode>PreInService</StatusCode>
</member>
{% endfor %}
</Activities>
</ExitStandbyResult>
<ResponseMetadata>
<RequestId>7c6e177f-f082-11e1-ac58-3714bEXAMPLE</RequestId>
</ResponseMetadata>
</ExitStandbyResponse>"""
TERMINATE_INSTANCES_TEMPLATE = """<TerminateInstanceInAutoScalingGroupResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<TerminateInstanceInAutoScalingGroupResult>
<Activity>
<ActivityId>35b5c464-0b63-2fc7-1611-467d4a7f2497EXAMPLE</ActivityId>
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
{% if should_decrement %}
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was taken out of service in response to a user request, shrinking the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
{% else %}
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was taken out of service in response to a user request.</Cause>
{% endif %}
<Description>Terminating EC2 instance: {{ instance.instance.id }}</Description>
<Progress>0</Progress>
<StartTime>{{ timestamp }}</StartTime>
<Details>{&quot;Subnet ID&quot;:&quot;??&quot;,&quot;Availability Zone&quot;:&quot;{{ instance.instance.placement }}&quot;}</Details>
<StatusCode>InProgress</StatusCode>
</Activity>
</TerminateInstanceInAutoScalingGroupResult>
<ResponseMetadata>
<RequestId>a1ba8fb9-31d6-4d9a-ace1-a7f76749df11EXAMPLE</RequestId>
</ResponseMetadata>
</TerminateInstanceInAutoScalingGroupResponse>"""

View File

@ -5,6 +5,8 @@ import time
from collections import defaultdict
import copy
import datetime
from gzip import GzipFile
import docker
import docker.errors
import hashlib
@ -15,18 +17,17 @@ import json
import re
import zipfile
import uuid
import functools
import tarfile
import calendar
import threading
import traceback
import weakref
import requests.adapters
import requests.exceptions
from boto3 import Session
from moto.awslambda.policy import Policy
from moto.core import BaseBackend, BaseModel
from moto.core import BaseBackend, CloudFormationModel
from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend
from moto.iam.exceptions import IAMNotFoundException
@ -45,6 +46,7 @@ from moto.sqs import sqs_backends
from moto.dynamodb2 import dynamodb_backends2
from moto.dynamodbstreams import dynamodbstreams_backends
from moto.core import ACCOUNT_ID
from moto.utilities.docker_utilities import DockerModel
logger = logging.getLogger(__name__)
@ -53,7 +55,6 @@ try:
except ImportError:
from backports.tempfile import TemporaryDirectory
_orig_adapter_send = requests.adapters.HTTPAdapter.send
docker_3 = docker.__version__[0] >= "3"
@ -149,8 +150,9 @@ class _DockerDataVolumeContext:
raise # multiple processes trying to use same volume?
class LambdaFunction(BaseModel):
class LambdaFunction(CloudFormationModel, DockerModel):
def __init__(self, spec, region, validate_s3=True, version=1):
DockerModel.__init__(self)
# required
self.region = region
self.code = spec["Code"]
@ -160,23 +162,9 @@ class LambdaFunction(BaseModel):
self.run_time = spec["Runtime"]
self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.docker_client = docker.from_env()
self.policy = None
self.state = "Active"
# Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked
if requests.adapters.HTTPAdapter.send != _orig_adapter_send:
_orig_get_adapter = self.docker_client.api.get_adapter
def replace_adapter_send(*args, **kwargs):
adapter = _orig_get_adapter(*args, **kwargs)
if isinstance(adapter, requests.adapters.HTTPAdapter):
adapter.send = functools.partial(_orig_adapter_send, adapter)
return adapter
self.docker_client.api.get_adapter = replace_adapter_send
self.reserved_concurrency = spec.get("ReservedConcurrentExecutions", None)
# optional
self.description = spec.get("Description", "")
@ -216,7 +204,7 @@ class LambdaFunction(BaseModel):
key = None
try:
# FIXME: does not validate bucket region
key = s3_backend.get_key(self.code["S3Bucket"], self.code["S3Key"])
key = s3_backend.get_object(self.code["S3Bucket"], self.code["S3Key"])
except MissingBucket:
if do_validate_s3():
raise InvalidParameterValueException(
@ -283,7 +271,7 @@ class LambdaFunction(BaseModel):
return config
def get_code(self):
return {
code = {
"Code": {
"Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(
self.region, self.code["S3Key"]
@ -292,6 +280,15 @@ class LambdaFunction(BaseModel):
},
"Configuration": self.get_configuration(),
}
if self.reserved_concurrency:
code.update(
{
"Concurrency": {
"ReservedConcurrentExecutions": self.reserved_concurrency
}
}
)
return code
def update_configuration(self, config_updates):
for key, value in config_updates.items():
@ -308,7 +305,7 @@ class LambdaFunction(BaseModel):
elif key == "Timeout":
self.timeout = value
elif key == "VpcConfig":
self.vpc_config = value
self._vpc_config = value
elif key == "Environment":
self.environment_vars = value["Variables"]
@ -342,7 +339,7 @@ class LambdaFunction(BaseModel):
key = None
try:
# FIXME: does not validate bucket region
key = s3_backend.get_key(
key = s3_backend.get_object(
updated_spec["S3Bucket"], updated_spec["S3Key"]
)
except MissingBucket:
@ -379,25 +376,32 @@ class LambdaFunction(BaseModel):
event = dict()
if context is None:
context = {}
output = None
try:
# TODO: I believe we can keep the container running and feed events as needed
# also need to hook it up to the other services so it can make kws/s3 etc calls
# Should get invoke_id /RequestId from invocation
env_vars = {
"_HANDLER": self.handler,
"AWS_EXECUTION_ENV": "AWS_Lambda_{}".format(self.run_time),
"AWS_LAMBDA_FUNCTION_TIMEOUT": self.timeout,
"AWS_LAMBDA_FUNCTION_NAME": self.function_name,
"AWS_LAMBDA_FUNCTION_MEMORY_SIZE": self.memory_size,
"AWS_LAMBDA_FUNCTION_VERSION": self.version,
"AWS_REGION": self.region,
"AWS_ACCESS_KEY_ID": "role-account-id",
"AWS_SECRET_ACCESS_KEY": "role-secret-key",
"AWS_SESSION_TOKEN": "session-token",
}
env_vars.update(self.environment_vars)
container = output = exit_code = None
container = exit_code = None
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
with _DockerDataVolumeContext(self) as data_vol:
try:
self.docker_client.ping() # Verify Docker is running
run_kwargs = (
dict(links={"motoserver": "motoserver"})
if settings.TEST_SERVER_MODE
@ -455,24 +459,31 @@ class LambdaFunction(BaseModel):
# We only care about the response from the lambda
# Which is the last line of the output, according to https://github.com/lambci/docker-lambda/issues/25
output = output.splitlines()[-1]
return output, False
resp = output.splitlines()[-1]
logs = os.linesep.join(
[line for line in self.convert(output).splitlines()[:-1]]
)
return resp, False, logs
except docker.errors.DockerException as e:
# Docker itself is probably not running - there will be no Lambda-logs to handle
return "error running docker: {}".format(e), True, ""
except BaseException as e:
traceback.print_exc()
return "error running lambda: {}".format(e), True
logs = os.linesep.join(
[line for line in self.convert(output).splitlines()[:-1]]
)
return "error running lambda: {}".format(e), True, logs
def invoke(self, body, request_headers, response_headers):
payload = dict()
if body:
body = json.loads(body)
# Get the invocation type:
res, errored = self._invoke_lambda(code=self.code, event=body)
res, errored, logs = self._invoke_lambda(code=self.code, event=body)
if request_headers.get("x-amz-invocation-type") == "RequestResponse":
encoded = base64.b64encode(res.encode("utf-8"))
encoded = base64.b64encode(logs.encode("utf-8"))
response_headers["x-amz-log-result"] = encoded.decode("utf-8")
payload["result"] = response_headers["x-amz-log-result"]
result = res.encode("utf-8")
else:
result = res
@ -481,11 +492,29 @@ class LambdaFunction(BaseModel):
return result
@staticmethod
def cloudformation_name_type():
return "FunctionName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-lambda-function.html
return "AWS::Lambda::Function"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
optional_properties = (
"Description",
"MemorySize",
"Publish",
"Timeout",
"VpcConfig",
"Environment",
"ReservedConcurrentExecutions",
)
# required
spec = {
@ -495,9 +524,7 @@ class LambdaFunction(BaseModel):
"Role": properties["Role"],
"Runtime": properties["Runtime"],
}
optional_properties = (
"Description MemorySize Publish Timeout VpcConfig Environment".split()
)
# NOTE: Not doing `properties.get(k, DEFAULT)` to avoid duplicating the
# default logic
for prop in optional_properties:
@ -545,43 +572,66 @@ class LambdaFunction(BaseModel):
lambda_backends[region].delete_function(self.function_name)
class EventSourceMapping(BaseModel):
class EventSourceMapping(CloudFormationModel):
def __init__(self, spec):
# required
self.function_arn = spec["FunctionArn"]
self.function_name = spec["FunctionName"]
self.event_source_arn = spec["EventSourceArn"]
# optional
self.batch_size = spec.get("BatchSize")
self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON")
self.enabled = spec.get("Enabled", True)
self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None)
self.function_arn = spec["FunctionArn"]
self.uuid = str(uuid.uuid4())
self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())
# BatchSize service default/max mapping
batch_size_map = {
def _get_service_source_from_arn(self, event_source_arn):
return event_source_arn.split(":")[2].lower()
def _validate_event_source(self, event_source_arn):
valid_services = ("dynamodb", "kinesis", "sqs")
service = self._get_service_source_from_arn(event_source_arn)
return True if service in valid_services else False
@property
def event_source_arn(self):
return self._event_source_arn
@event_source_arn.setter
def event_source_arn(self, event_source_arn):
if not self._validate_event_source(event_source_arn):
raise ValueError(
"InvalidParameterValueException", "Unsupported event source type"
)
self._event_source_arn = event_source_arn
@property
def batch_size(self):
return self._batch_size
@batch_size.setter
def batch_size(self, batch_size):
batch_size_service_map = {
"kinesis": (100, 10000),
"dynamodb": (100, 1000),
"sqs": (10, 10),
}
source_type = self.event_source_arn.split(":")[2].lower()
batch_size_entry = batch_size_map.get(source_type)
if batch_size_entry:
# Use service default if not provided
batch_size = int(spec.get("BatchSize", batch_size_entry[0]))
if batch_size > batch_size_entry[1]:
raise ValueError(
"InvalidParameterValueException",
"BatchSize {} exceeds the max of {}".format(
batch_size, batch_size_entry[1]
),
)
else:
self.batch_size = batch_size
else:
raise ValueError(
"InvalidParameterValueException", "Unsupported event source type"
)
# optional
self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON")
self.enabled = spec.get("Enabled", True)
self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None)
source_type = self._get_service_source_from_arn(self.event_source_arn)
batch_size_for_source = batch_size_service_map[source_type]
if batch_size is None:
self._batch_size = batch_size_for_source[0]
elif batch_size > batch_size_for_source[1]:
error_message = "BatchSize {} exceeds the max of {}".format(
batch_size, batch_size_for_source[1]
)
raise ValueError("InvalidParameterValueException", error_message)
else:
self._batch_size = int(batch_size)
def get_configuration(self):
return {
@ -595,32 +645,72 @@ class EventSourceMapping(BaseModel):
"StateTransitionReason": "User initiated",
}
def delete(self, region_name):
lambda_backend = lambda_backends[region_name]
lambda_backend.delete_event_source_mapping(self.uuid)
@staticmethod
def cloudformation_name_type():
return None
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-lambda-eventsourcemapping.html
return "AWS::Lambda::EventSourceMapping"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
func = lambda_backends[region_name].get_function(properties["FunctionName"])
spec = {
"FunctionArn": func.function_arn,
"EventSourceArn": properties["EventSourceArn"],
"StartingPosition": properties["StartingPosition"],
"BatchSize": properties.get("BatchSize", 100),
}
optional_properties = "BatchSize Enabled StartingPositionTimestamp".split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]
return EventSourceMapping(spec)
lambda_backend = lambda_backends[region_name]
return lambda_backend.create_event_source_mapping(properties)
@classmethod
def update_from_cloudformation_json(
cls, new_resource_name, cloudformation_json, original_resource, region_name
):
properties = cloudformation_json["Properties"]
event_source_uuid = original_resource.uuid
lambda_backend = lambda_backends[region_name]
return lambda_backend.update_event_source_mapping(event_source_uuid, properties)
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
lambda_backend = lambda_backends[region_name]
esms = lambda_backend.list_event_source_mappings(
event_source_arn=properties["EventSourceArn"],
function_name=properties["FunctionName"],
)
for esm in esms:
if esm.uuid == resource_name:
esm.delete(region_name)
@property
def physical_resource_id(self):
return self.uuid
class LambdaVersion(BaseModel):
class LambdaVersion(CloudFormationModel):
def __init__(self, spec):
self.version = spec["Version"]
def __repr__(self):
return str(self.logical_resource_id)
@staticmethod
def cloudformation_name_type():
return None
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-lambda-version.html
return "AWS::Lambda::Version"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
@ -812,7 +902,7 @@ class LambdaBackend(BaseBackend):
)
# Validate function name
func = self._lambdas.get_function_by_name_or_arn(spec.pop("FunctionName", ""))
func = self._lambdas.get_function_by_name_or_arn(spec.get("FunctionName", ""))
if not func:
raise RESTError("ResourceNotFoundException", "Invalid FunctionName")
@ -870,18 +960,20 @@ class LambdaBackend(BaseBackend):
def update_event_source_mapping(self, uuid, spec):
esm = self.get_event_source_mapping(uuid)
if esm:
if spec.get("FunctionName"):
func = self._lambdas.get_function_by_name_or_arn(
spec.get("FunctionName")
)
if not esm:
return False
for key, value in spec.items():
if key == "FunctionName":
func = self._lambdas.get_function_by_name_or_arn(spec[key])
esm.function_arn = func.function_arn
if "BatchSize" in spec:
esm.batch_size = spec["BatchSize"]
if "Enabled" in spec:
esm.enabled = spec["Enabled"]
return esm
return False
elif key == "BatchSize":
esm.batch_size = spec[key]
elif key == "Enabled":
esm.enabled = spec[key]
esm.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())
return esm
def list_event_source_mappings(self, event_source_arn, function_name):
esms = list(self._event_source_mappings.values())
@ -981,7 +1073,29 @@ class LambdaBackend(BaseBackend):
]
}
func = self._lambdas.get_arn(function_arn)
func.invoke(json.dumps(event), {}, {})
return func.invoke(json.dumps(event), {}, {})
def send_log_event(
self, function_arn, filter_name, log_group_name, log_stream_name, log_events
):
data = {
"messageType": "DATA_MESSAGE",
"owner": ACCOUNT_ID,
"logGroup": log_group_name,
"logStream": log_stream_name,
"subscriptionFilters": [filter_name],
"logEvents": log_events,
}
output = io.BytesIO()
with GzipFile(fileobj=output, mode="w") as f:
f.write(json.dumps(data, separators=(",", ":")).encode("utf-8"))
payload_gz_encoded = base64.b64encode(output.getvalue()).decode("utf-8")
event = {"awslogs": {"data": payload_gz_encoded}}
func = self._lambdas.get_arn(function_arn)
return func.invoke(json.dumps(event), {}, {})
def list_tags(self, resource):
return self.get_function_by_arn(resource).tags
@ -1006,11 +1120,11 @@ class LambdaBackend(BaseBackend):
return True
return False
def add_policy_statement(self, function_name, raw):
def add_permission(self, function_name, raw):
fn = self.get_function(function_name)
fn.policy.add_statement(raw)
def del_policy_statement(self, function_name, sid, revision=""):
def remove_permission(self, function_name, sid, revision=""):
fn = self.get_function(function_name)
fn.policy.del_statement(sid, revision)
@ -1044,9 +1158,23 @@ class LambdaBackend(BaseBackend):
if fn:
payload = fn.invoke(body, headers, response_headers)
response_headers["Content-Length"] = str(len(payload))
return response_headers, payload
return payload
else:
return response_headers, None
return None
def put_function_concurrency(self, function_name, reserved_concurrency):
fn = self.get_function(function_name)
fn.reserved_concurrency = reserved_concurrency
return fn.reserved_concurrency
def delete_function_concurrency(self, function_name):
fn = self.get_function(function_name)
fn.reserved_concurrency = None
return fn.reserved_concurrency
def get_function_concurrency(self, function_name):
fn = self.get_function(function_name)
return fn.reserved_concurrency
def do_validate_s3():

View File

@ -141,12 +141,25 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def function_concurrency(self, request, full_url, headers):
http_method = request.method
self.setup_class(request, full_url, headers)
if http_method == "GET":
return self._get_function_concurrency(request)
elif http_method == "DELETE":
return self._delete_function_concurrency(request)
elif http_method == "PUT":
return self._put_function_concurrency(request)
else:
raise ValueError("Cannot handle request")
def _add_policy(self, request, full_url, headers):
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
statement = self.body
self.lambda_backend.add_policy_statement(function_name, statement)
self.lambda_backend.add_permission(function_name, statement)
return 200, {}, json.dumps({"Statement": statement})
else:
return 404, {}, "{}"
@ -166,9 +179,7 @@ class LambdaResponse(BaseResponse):
statement_id = path.split("/")[-1].split("?")[0]
revision = querystring.get("RevisionId", "")
if self.lambda_backend.get_function(function_name):
self.lambda_backend.del_policy_statement(
function_name, statement_id, revision
)
self.lambda_backend.remove_permission(function_name, statement_id, revision)
return 204, {}, "{}"
else:
return 404, {}, "{}"
@ -180,11 +191,19 @@ class LambdaResponse(BaseResponse):
function_name = unquote(self.path.rsplit("/", 2)[-2])
qualifier = self._get_param("qualifier")
response_header, payload = self.lambda_backend.invoke(
payload = self.lambda_backend.invoke(
function_name, qualifier, self.body, self.headers, response_headers
)
if payload:
return 202, response_headers, payload
if request.headers.get("X-Amz-Invocation-Type") == "Event":
status_code = 202
elif request.headers.get("X-Amz-Invocation-Type") == "DryRun":
status_code = 204
else:
if request.headers.get("X-Amz-Log-Type") != "Tail":
del response_headers["x-amz-log-result"]
status_code = 200
return status_code, response_headers, payload
else:
return 404, response_headers, "{}"
@ -295,7 +314,7 @@ class LambdaResponse(BaseResponse):
code["Configuration"]["FunctionArn"] += ":$LATEST"
return 200, {}, json.dumps(code)
else:
return 404, {}, "{}"
return 404, {"x-amzn-ErrorType": "ResourceNotFoundException"}, "{}"
def _get_aws_region(self, full_url):
region = self.region_regex.search(full_url)
@ -353,3 +372,38 @@ class LambdaResponse(BaseResponse):
return 200, {}, json.dumps(resp)
else:
return 404, {}, "{}"
def _get_function_concurrency(self, request):
path_function_name = self.path.rsplit("/", 2)[-2]
function_name = self.lambda_backend.get_function(path_function_name)
if function_name is None:
return 404, {}, "{}"
resp = self.lambda_backend.get_function_concurrency(path_function_name)
return 200, {}, json.dumps({"ReservedConcurrentExecutions": resp})
def _delete_function_concurrency(self, request):
path_function_name = self.path.rsplit("/", 2)[-2]
function_name = self.lambda_backend.get_function(path_function_name)
if function_name is None:
return 404, {}, "{}"
self.lambda_backend.delete_function_concurrency(path_function_name)
return 204, {}, "{}"
def _put_function_concurrency(self, request):
path_function_name = self.path.rsplit("/", 2)[-2]
function = self.lambda_backend.get_function(path_function_name)
if function is None:
return 404, {}, "{}"
concurrency = self._get_param("ReservedConcurrentExecutions", None)
resp = self.lambda_backend.put_function_concurrency(
path_function_name, concurrency
)
return 200, {}, json.dumps({"ReservedConcurrentExecutions": resp})

View File

@ -19,4 +19,5 @@ url_paths = {
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/concurrency/?$": response.function_concurrency,
}

View File

@ -1,122 +1,113 @@
from __future__ import unicode_literals
from moto.acm import acm_backends
from moto.apigateway import apigateway_backends
from moto.athena import athena_backends
from moto.autoscaling import autoscaling_backends
from moto.awslambda import lambda_backends
from moto.batch import batch_backends
from moto.cloudformation import cloudformation_backends
from moto.cloudwatch import cloudwatch_backends
from moto.codecommit import codecommit_backends
from moto.codepipeline import codepipeline_backends
from moto.cognitoidentity import cognitoidentity_backends
from moto.cognitoidp import cognitoidp_backends
from moto.config import config_backends
from moto.core import moto_api_backends
from moto.datapipeline import datapipeline_backends
from moto.datasync import datasync_backends
from moto.dynamodb import dynamodb_backends
from moto.dynamodb2 import dynamodb_backends2
from moto.dynamodbstreams import dynamodbstreams_backends
from moto.ec2 import ec2_backends
from moto.ec2_instance_connect import ec2_instance_connect_backends
from moto.ecr import ecr_backends
from moto.ecs import ecs_backends
from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
from moto.emr import emr_backends
from moto.events import events_backends
from moto.glacier import glacier_backends
from moto.glue import glue_backends
from moto.iam import iam_backends
from moto.instance_metadata import instance_metadata_backends
from moto.iot import iot_backends
from moto.iotdata import iotdata_backends
from moto.kinesis import kinesis_backends
from moto.kms import kms_backends
from moto.logs import logs_backends
from moto.opsworks import opsworks_backends
from moto.organizations import organizations_backends
from moto.polly import polly_backends
from moto.rds2 import rds2_backends
from moto.redshift import redshift_backends
from moto.resourcegroups import resourcegroups_backends
from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends
from moto.route53 import route53_backends
from moto.s3 import s3_backends
from moto.secretsmanager import secretsmanager_backends
from moto.ses import ses_backends
from moto.sns import sns_backends
from moto.sqs import sqs_backends
from moto.ssm import ssm_backends
from moto.stepfunctions import stepfunction_backends
from moto.sts import sts_backends
from moto.swf import swf_backends
from moto.xray import xray_backends
import importlib
BACKENDS = {
"acm": acm_backends,
"apigateway": apigateway_backends,
"athena": athena_backends,
"autoscaling": autoscaling_backends,
"batch": batch_backends,
"cloudformation": cloudformation_backends,
"cloudwatch": cloudwatch_backends,
"codecommit": codecommit_backends,
"codepipeline": codepipeline_backends,
"cognito-identity": cognitoidentity_backends,
"cognito-idp": cognitoidp_backends,
"config": config_backends,
"datapipeline": datapipeline_backends,
"datasync": datasync_backends,
"dynamodb": dynamodb_backends,
"dynamodb2": dynamodb_backends2,
"dynamodbstreams": dynamodbstreams_backends,
"ec2": ec2_backends,
"ec2_instance_connect": ec2_instance_connect_backends,
"ecr": ecr_backends,
"ecs": ecs_backends,
"elb": elb_backends,
"elbv2": elbv2_backends,
"events": events_backends,
"emr": emr_backends,
"glacier": glacier_backends,
"glue": glue_backends,
"iam": iam_backends,
"moto_api": moto_api_backends,
"instance_metadata": instance_metadata_backends,
"logs": logs_backends,
"kinesis": kinesis_backends,
"kms": kms_backends,
"opsworks": opsworks_backends,
"organizations": organizations_backends,
"polly": polly_backends,
"redshift": redshift_backends,
"resource-groups": resourcegroups_backends,
"rds": rds2_backends,
"s3": s3_backends,
"s3bucket_path": s3_backends,
"ses": ses_backends,
"secretsmanager": secretsmanager_backends,
"sns": sns_backends,
"sqs": sqs_backends,
"ssm": ssm_backends,
"stepfunctions": stepfunction_backends,
"sts": sts_backends,
"swf": swf_backends,
"route53": route53_backends,
"lambda": lambda_backends,
"xray": xray_backends,
"resourcegroupstaggingapi": resourcegroupstaggingapi_backends,
"iot": iot_backends,
"iot-data": iotdata_backends,
"acm": ("acm", "acm_backends"),
"apigateway": ("apigateway", "apigateway_backends"),
"athena": ("athena", "athena_backends"),
"applicationautoscaling": (
"applicationautoscaling",
"applicationautoscaling_backends",
),
"autoscaling": ("autoscaling", "autoscaling_backends"),
"batch": ("batch", "batch_backends"),
"cloudformation": ("cloudformation", "cloudformation_backends"),
"cloudwatch": ("cloudwatch", "cloudwatch_backends"),
"codecommit": ("codecommit", "codecommit_backends"),
"codepipeline": ("codepipeline", "codepipeline_backends"),
"cognito-identity": ("cognitoidentity", "cognitoidentity_backends"),
"cognito-idp": ("cognitoidp", "cognitoidp_backends"),
"config": ("config", "config_backends"),
"datapipeline": ("datapipeline", "datapipeline_backends"),
"datasync": ("datasync", "datasync_backends"),
"dynamodb": ("dynamodb", "dynamodb_backends"),
"dynamodb2": ("dynamodb2", "dynamodb_backends2"),
"dynamodbstreams": ("dynamodbstreams", "dynamodbstreams_backends"),
"ec2": ("ec2", "ec2_backends"),
"ec2instanceconnect": ("ec2instanceconnect", "ec2instanceconnect_backends"),
"ecr": ("ecr", "ecr_backends"),
"ecs": ("ecs", "ecs_backends"),
"elasticbeanstalk": ("elasticbeanstalk", "eb_backends"),
"elb": ("elb", "elb_backends"),
"elbv2": ("elbv2", "elbv2_backends"),
"emr": ("emr", "emr_backends"),
"events": ("events", "events_backends"),
"glacier": ("glacier", "glacier_backends"),
"glue": ("glue", "glue_backends"),
"iam": ("iam", "iam_backends"),
"instance_metadata": ("instance_metadata", "instance_metadata_backends"),
"iot": ("iot", "iot_backends"),
"iot-data": ("iotdata", "iotdata_backends"),
"kinesis": ("kinesis", "kinesis_backends"),
"kms": ("kms", "kms_backends"),
"lambda": ("awslambda", "lambda_backends"),
"logs": ("logs", "logs_backends"),
"managedblockchain": ("managedblockchain", "managedblockchain_backends"),
"moto_api": ("core", "moto_api_backends"),
"opsworks": ("opsworks", "opsworks_backends"),
"organizations": ("organizations", "organizations_backends"),
"polly": ("polly", "polly_backends"),
"ram": ("ram", "ram_backends"),
"rds": ("rds2", "rds2_backends"),
"redshift": ("redshift", "redshift_backends"),
"resource-groups": ("resourcegroups", "resourcegroups_backends"),
"resourcegroupstaggingapi": (
"resourcegroupstaggingapi",
"resourcegroupstaggingapi_backends",
),
"route53": ("route53", "route53_backends"),
"s3": ("s3", "s3_backends"),
"s3bucket_path": ("s3", "s3_backends"),
"sagemaker": ("sagemaker", "sagemaker_backends"),
"secretsmanager": ("secretsmanager", "secretsmanager_backends"),
"ses": ("ses", "ses_backends"),
"sns": ("sns", "sns_backends"),
"sqs": ("sqs", "sqs_backends"),
"ssm": ("ssm", "ssm_backends"),
"stepfunctions": ("stepfunctions", "stepfunction_backends"),
"sts": ("sts", "sts_backends"),
"swf": ("swf", "swf_backends"),
"transcribe": ("transcribe", "transcribe_backends"),
"xray": ("xray", "xray_backends"),
"kinesisvideo": ("kinesisvideo", "kinesisvideo_backends"),
"kinesis-video-archived-media": (
"kinesisvideoarchivedmedia",
"kinesisvideoarchivedmedia_backends",
),
"forecast": ("forecast", "forecast_backends"),
}
def _import_backend(module_name, backends_name):
module = importlib.import_module("moto." + module_name)
return getattr(module, backends_name)
def backends():
for module_name, backends_name in BACKENDS.values():
yield _import_backend(module_name, backends_name)
def named_backends():
for name, (module_name, backends_name) in BACKENDS.items():
yield name, _import_backend(module_name, backends_name)
def get_backend(name):
module_name, backends_name = BACKENDS[name]
return _import_backend(module_name, backends_name)
def search_backend(predicate):
for name, backend in named_backends():
if predicate(backend):
return name
def get_model(name, region_name):
for backends in BACKENDS.values():
for region, backend in backends.items():
for backends_ in backends():
for region, backend in backends_.items():
if region == region_name:
models = getattr(backend.__class__, "__models__", {})
if name in models:

View File

@ -1,40 +1,24 @@
from __future__ import unicode_literals
import json
class AWSError(Exception):
CODE = None
STATUS = 400
def __init__(self, message, code=None, status=None):
self.message = message
self.code = code if code is not None else self.CODE
self.status = status if status is not None else self.STATUS
def response(self):
return (
json.dumps({"__type": self.code, "message": self.message}),
dict(status=self.status),
)
from moto.core.exceptions import AWSError
class InvalidRequestException(AWSError):
CODE = "InvalidRequestException"
TYPE = "InvalidRequestException"
class InvalidParameterValueException(AWSError):
CODE = "InvalidParameterValue"
TYPE = "InvalidParameterValue"
class ValidationError(AWSError):
CODE = "ValidationError"
TYPE = "ValidationError"
class InternalFailure(AWSError):
CODE = "InternalFailure"
TYPE = "InternalFailure"
STATUS = 500
class ClientException(AWSError):
CODE = "ClientException"
TYPE = "ClientException"
STATUS = 400

View File

@ -1,6 +1,5 @@
from __future__ import unicode_literals
import re
import requests.adapters
from itertools import cycle
import six
import datetime
@ -8,12 +7,11 @@ import time
import uuid
import logging
import docker
import functools
import threading
import dateutil.parser
from boto3 import Session
from moto.core import BaseBackend, BaseModel
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.iam import iam_backends
from moto.ec2 import ec2_backends
from moto.ecs import ecs_backends
@ -30,8 +28,8 @@ from moto.ec2.exceptions import InvalidSubnetIdError
from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES
from moto.iam.exceptions import IAMNotFoundException
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from moto.utilities.docker_utilities import DockerModel
_orig_adapter_send = requests.adapters.HTTPAdapter.send
logger = logging.getLogger(__name__)
COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(
r"^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$"
@ -42,7 +40,7 @@ def datetime2int(date):
return int(time.mktime(date.timetuple()))
class ComputeEnvironment(BaseModel):
class ComputeEnvironment(CloudFormationModel):
def __init__(
self,
compute_environment_name,
@ -76,6 +74,15 @@ class ComputeEnvironment(BaseModel):
def physical_resource_id(self):
return self.arn
@staticmethod
def cloudformation_name_type():
return "ComputeEnvironmentName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-computeenvironment.html
return "AWS::Batch::ComputeEnvironment"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
@ -95,7 +102,7 @@ class ComputeEnvironment(BaseModel):
return backend.get_compute_environment_by_arn(arn)
class JobQueue(BaseModel):
class JobQueue(CloudFormationModel):
def __init__(
self, name, priority, state, environments, env_order_json, region_name
):
@ -139,6 +146,15 @@ class JobQueue(BaseModel):
def physical_resource_id(self):
return self.arn
@staticmethod
def cloudformation_name_type():
return "JobQueueName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobqueue.html
return "AWS::Batch::JobQueue"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
@ -164,7 +180,7 @@ class JobQueue(BaseModel):
return backend.get_job_queue_by_arn(arn)
class JobDefinition(BaseModel):
class JobDefinition(CloudFormationModel):
def __init__(
self,
name,
@ -264,6 +280,15 @@ class JobDefinition(BaseModel):
def physical_resource_id(self):
return self.arn
@staticmethod
def cloudformation_name_type():
return "JobDefinitionName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobdefinition.html
return "AWS::Batch::JobDefinition"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
@ -284,7 +309,7 @@ class JobDefinition(BaseModel):
return backend.get_job_definition_by_arn(arn)
class Job(threading.Thread, BaseModel):
class Job(threading.Thread, BaseModel, DockerModel):
def __init__(self, name, job_def, job_queue, log_backend, container_overrides):
"""
Docker Job
@ -297,11 +322,12 @@ class Job(threading.Thread, BaseModel):
:type log_backend: moto.logs.models.LogsBackend
"""
threading.Thread.__init__(self)
DockerModel.__init__(self)
self.job_name = name
self.job_id = str(uuid.uuid4())
self.job_definition = job_def
self.container_overrides = container_overrides
self.container_overrides = container_overrides or {}
self.job_queue = job_queue
self.job_state = "SUBMITTED" # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED
self.job_queue.jobs.append(self)
@ -315,22 +341,8 @@ class Job(threading.Thread, BaseModel):
self.daemon = True
self.name = "MOTO-BATCH-" + self.job_id
self.docker_client = docker.from_env()
self._log_backend = log_backend
# Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked
if requests.adapters.HTTPAdapter.send != _orig_adapter_send:
_orig_get_adapter = self.docker_client.api.get_adapter
def replace_adapter_send(*args, **kwargs):
adapter = _orig_get_adapter(*args, **kwargs)
if isinstance(adapter, requests.adapters.HTTPAdapter):
adapter.send = functools.partial(_orig_adapter_send, adapter)
return adapter
self.docker_client.api.get_adapter = replace_adapter_send
self.log_stream_name = None
def describe(self):
result = {
@ -338,10 +350,11 @@ class Job(threading.Thread, BaseModel):
"jobId": self.job_id,
"jobName": self.job_name,
"jobQueue": self.job_queue.arn,
"startedAt": datetime2int(self.job_started_at),
"status": self.job_state,
"dependsOn": [],
}
if result["status"] not in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING"]:
result["startedAt"] = datetime2int(self.job_started_at)
if self.job_stopped:
result["stoppedAt"] = datetime2int(self.job_stopped_at)
result["container"] = {}
@ -379,7 +392,6 @@ class Job(threading.Thread, BaseModel):
"""
try:
self.job_state = "PENDING"
time.sleep(1)
image = self.job_definition.container_properties.get(
"image", "alpine:latest"
@ -412,8 +424,8 @@ class Job(threading.Thread, BaseModel):
self.job_state = "RUNNABLE"
# TODO setup ecs container instance
time.sleep(1)
self.job_started_at = datetime.datetime.now()
self.job_state = "STARTING"
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
container = self.docker_client.containers.run(
@ -427,58 +439,24 @@ class Job(threading.Thread, BaseModel):
privileged=privileged,
)
self.job_state = "RUNNING"
self.job_started_at = datetime.datetime.now()
try:
# Log collection
logs_stdout = []
logs_stderr = []
container.reload()
# Dodgy hack, we can only check docker logs once a second, but we want to loop more
# so we can stop if asked to in a quick manner, should all go away if we go async
# There also be some dodgyness when sending an integer to docker logs and some
# events seem to be duplicated.
now = datetime.datetime.now()
i = 1
while container.status == "running" and not self.stop:
time.sleep(0.15)
if i % 10 == 0:
logs_stderr.extend(
container.logs(
stdout=False,
stderr=True,
timestamps=True,
since=datetime2int(now),
)
.decode()
.split("\n")
)
logs_stdout.extend(
container.logs(
stdout=True,
stderr=False,
timestamps=True,
since=datetime2int(now),
)
.decode()
.split("\n")
)
now = datetime.datetime.now()
container.reload()
i += 1
container.reload()
# Container should be stopped by this point... unless asked to stop
if container.status == "running":
container.kill()
self.job_stopped_at = datetime.datetime.now()
# Get final logs
# Log collection
logs_stdout = []
logs_stderr = []
logs_stderr.extend(
container.logs(
stdout=False,
stderr=True,
timestamps=True,
since=datetime2int(now),
since=datetime2int(self.job_started_at),
)
.decode()
.split("\n")
@ -488,14 +466,12 @@ class Job(threading.Thread, BaseModel):
stdout=True,
stderr=False,
timestamps=True,
since=datetime2int(now),
since=datetime2int(self.job_started_at),
)
.decode()
.split("\n")
)
self.job_state = "SUCCEEDED" if not self.stop else "FAILED"
# Process logs
logs_stdout = [x for x in logs_stdout if len(x) > 0]
logs_stderr = [x for x in logs_stderr if len(x) > 0]
@ -503,7 +479,10 @@ class Job(threading.Thread, BaseModel):
for line in logs_stdout + logs_stderr:
date, line = line.split(" ", 1)
date = dateutil.parser.parse(date)
date = int(date.timestamp())
# TODO: Replace with int(date.timestamp()) once we yeet Python2 out of the window
date = int(
(time.mktime(date.timetuple()) + date.microsecond / 1000000.0)
)
logs.append({"timestamp": date, "message": line.strip()})
# Send to cloudwatch
@ -516,6 +495,8 @@ class Job(threading.Thread, BaseModel):
self._log_backend.create_log_stream(log_group, stream_name)
self._log_backend.put_log_events(log_group, stream_name, logs, None)
self.job_state = "SUCCEEDED" if not self.stop else "FAILED"
except Exception as err:
logger.error(
"Failed to run AWS Batch container {0}. Error {1}".format(

View File

@ -21,6 +21,14 @@ def lowercase_first_key(some_dict):
new_dict = {}
for key, value in some_dict.items():
new_key = key[0].lower() + key[1:]
new_dict[new_key] = value
try:
if isinstance(value, dict):
new_dict[new_key] = lowercase_first_key(value)
elif all([isinstance(v, dict) for v in value]):
new_dict[new_key] = [lowercase_first_key(v) for v in value]
else:
new_dict[new_key] = value
except TypeError:
new_dict[new_key] = value
return new_dict

View File

@ -8,6 +8,7 @@ from boto3 import Session
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .parsing import ResourceMap, OutputMap
from .utils import (
@ -218,7 +219,12 @@ class FakeStack(BaseModel):
self.stack_id = stack_id
self.name = name
self.template = template
self._parse_template()
if template != {}:
self._parse_template()
self.description = self.template_dict.get("Description")
else:
self.template_dict = {}
self.description = None
self.parameters = parameters
self.region_name = region_name
self.notification_arns = notification_arns if notification_arns else []
@ -234,12 +240,16 @@ class FakeStack(BaseModel):
"CREATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.description = self.template_dict.get("Description")
self.cross_stack_resources = cross_stack_resources or {}
self.resource_map = self._create_resource_map()
self.output_map = self._create_output_map()
self._add_stack_event("CREATE_COMPLETE")
self.status = "CREATE_COMPLETE"
if create_change_set:
self.status = "CREATE_COMPLETE"
self.execution_status = "AVAILABLE"
else:
self.create_resources()
self._add_stack_event("CREATE_COMPLETE")
self.creation_time = datetime.utcnow()
def _create_resource_map(self):
resource_map = ResourceMap(
@ -251,7 +261,7 @@ class FakeStack(BaseModel):
self.template_dict,
self.cross_stack_resources,
)
resource_map.create()
resource_map.load()
return resource_map
def _create_output_map(self):
@ -259,6 +269,10 @@ class FakeStack(BaseModel):
output_map.create()
return output_map
@property
def creation_time_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.creation_time)
def _add_stack_event(
self, resource_status, resource_status_reason=None, resource_properties=None
):
@ -301,8 +315,8 @@ class FakeStack(BaseModel):
yaml.add_multi_constructor("", yaml_tag_constructor)
try:
self.template_dict = yaml.load(self.template, Loader=yaml.Loader)
except yaml.parser.ParserError:
self.template_dict = json.loads(self.template, Loader=yaml.Loader)
except (yaml.parser.ParserError, yaml.scanner.ScannerError):
self.template_dict = json.loads(self.template)
@property
def stack_parameters(self):
@ -320,6 +334,12 @@ class FakeStack(BaseModel):
def exports(self):
return self.output_map.exports
def create_resources(self):
self.resource_map.create(self.template_dict)
# Set the description of the stack
self.description = self.template_dict.get("Description")
self.status = "CREATE_COMPLETE"
def update(self, template, role_arn=None, parameters=None, tags=None):
self._add_stack_event(
"UPDATE_IN_PROGRESS", resource_status_reason="User Initiated"
@ -384,6 +404,9 @@ class FakeChangeSet(FakeStack):
self.change_set_id = change_set_id
self.change_set_name = change_set_name
self.changes = self.diff(template=template, parameters=parameters)
if self.description is None:
self.description = self.template_dict.get("Description")
self.creation_time = datetime.utcnow()
def diff(self, template, parameters=None):
self.template = template
@ -426,6 +449,16 @@ class FakeEvent(BaseModel):
self.event_id = uuid.uuid4()
def filter_stacks(all_stacks, status_filter):
filtered_stacks = []
if not status_filter:
return all_stacks
for stack in all_stacks:
if stack.status in status_filter:
filtered_stacks.append(stack)
return filtered_stacks
class CloudFormationBackend(BaseBackend):
def __init__(self):
self.stacks = OrderedDict()
@ -574,8 +607,8 @@ class CloudFormationBackend(BaseBackend):
if stack is None:
raise ValidationError(stack_name)
else:
stack_id = generate_stack_id(stack_name)
stack_template = template
stack_id = generate_stack_id(stack_name, region_name)
stack_template = {}
change_set_id = generate_changeset_id(change_set_name, region_name)
new_change_set = FakeChangeSet(
@ -630,10 +663,14 @@ class CloudFormationBackend(BaseBackend):
if stack is None:
raise ValidationError(stack_name)
if stack.events[-1].resource_status == "REVIEW_IN_PROGRESS":
stack._add_stack_event(
"CREATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
stack._add_stack_event("CREATE_COMPLETE")
else:
stack._add_stack_event("UPDATE_IN_PROGRESS")
stack._add_stack_event("UPDATE_COMPLETE")
stack.create_resources()
return True
def describe_stacks(self, name_or_stack_id):
@ -654,10 +691,11 @@ class CloudFormationBackend(BaseBackend):
def list_change_sets(self):
return self.change_sets.values()
def list_stacks(self):
return [v for v in self.stacks.values()] + [
def list_stacks(self, status_filter=None):
total_stacks = [v for v in self.stacks.values()] + [
v for v in self.deleted_stacks.values()
]
return filter_stacks(total_stacks, status_filter)
def get_stack(self, name_or_stack_id):
all_stacks = dict(self.deleted_stacks, **self.stacks)

View File

@ -1,33 +1,48 @@
from __future__ import unicode_literals
import functools
import json
import logging
import copy
import warnings
import re
from moto.autoscaling import models as autoscaling_models
from moto.awslambda import models as lambda_models
from moto.batch import models as batch_models
from moto.cloudwatch import models as cloudwatch_models
from moto.cognitoidentity import models as cognitoidentity_models
from moto.compat import collections_abc
from moto.datapipeline import models as datapipeline_models
from moto.dynamodb2 import models as dynamodb2_models
# This ugly section of imports is necessary because we
# build the list of CloudFormationModel subclasses using
# CloudFormationModel.__subclasses__(). However, if the class
# definition of a subclass hasn't been executed yet - for example, if
# the subclass's module hasn't been imported yet - then that subclass
# doesn't exist yet, and __subclasses__ won't find it.
# So we import here to populate the list of subclasses.
from moto.autoscaling import models as autoscaling_models # noqa
from moto.awslambda import models as awslambda_models # noqa
from moto.batch import models as batch_models # noqa
from moto.cloudwatch import models as cloudwatch_models # noqa
from moto.datapipeline import models as datapipeline_models # noqa
from moto.dynamodb2 import models as dynamodb2_models # noqa
from moto.ec2 import models as ec2_models
from moto.ecs import models as ecs_models
from moto.elb import models as elb_models
from moto.elbv2 import models as elbv2_models
from moto.iam import models as iam_models
from moto.kinesis import models as kinesis_models
from moto.kms import models as kms_models
from moto.rds import models as rds_models
from moto.rds2 import models as rds2_models
from moto.redshift import models as redshift_models
from moto.route53 import models as route53_models
from moto.s3 import models as s3_models
from moto.sns import models as sns_models
from moto.sqs import models as sqs_models
from moto.core import ACCOUNT_ID
from moto.ecr import models as ecr_models # noqa
from moto.ecs import models as ecs_models # noqa
from moto.elb import models as elb_models # noqa
from moto.elbv2 import models as elbv2_models # noqa
from moto.events import models as events_models # noqa
from moto.iam import models as iam_models # noqa
from moto.kinesis import models as kinesis_models # noqa
from moto.kms import models as kms_models # noqa
from moto.rds import models as rds_models # noqa
from moto.rds2 import models as rds2_models # noqa
from moto.redshift import models as redshift_models # noqa
from moto.route53 import models as route53_models # noqa
from moto.s3 import models as s3_models, s3_backend # noqa
from moto.s3.utils import bucket_and_name_from_url
from moto.sns import models as sns_models # noqa
from moto.sqs import models as sqs_models # noqa
from moto.stepfunctions import models as stepfunctions_models # noqa
# End ugly list of imports
from moto.core import ACCOUNT_ID, CloudFormationModel
from .utils import random_suffix
from .exceptions import (
ExportNotFound,
@ -35,78 +50,14 @@ from .exceptions import (
UnformattedGetAttTemplateException,
ValidationError,
)
from boto.cloudformation.stack import Output
from moto.packages.boto.cloudformation.stack import Output
MODEL_MAP = {
"AWS::AutoScaling::AutoScalingGroup": autoscaling_models.FakeAutoScalingGroup,
"AWS::AutoScaling::LaunchConfiguration": autoscaling_models.FakeLaunchConfiguration,
"AWS::Batch::JobDefinition": batch_models.JobDefinition,
"AWS::Batch::JobQueue": batch_models.JobQueue,
"AWS::Batch::ComputeEnvironment": batch_models.ComputeEnvironment,
"AWS::DynamoDB::Table": dynamodb2_models.Table,
"AWS::Kinesis::Stream": kinesis_models.Stream,
"AWS::Lambda::EventSourceMapping": lambda_models.EventSourceMapping,
"AWS::Lambda::Function": lambda_models.LambdaFunction,
"AWS::Lambda::Version": lambda_models.LambdaVersion,
"AWS::EC2::EIP": ec2_models.ElasticAddress,
"AWS::EC2::Instance": ec2_models.Instance,
"AWS::EC2::InternetGateway": ec2_models.InternetGateway,
"AWS::EC2::NatGateway": ec2_models.NatGateway,
"AWS::EC2::NetworkInterface": ec2_models.NetworkInterface,
"AWS::EC2::Route": ec2_models.Route,
"AWS::EC2::RouteTable": ec2_models.RouteTable,
"AWS::EC2::SecurityGroup": ec2_models.SecurityGroup,
"AWS::EC2::SecurityGroupIngress": ec2_models.SecurityGroupIngress,
"AWS::EC2::SpotFleet": ec2_models.SpotFleetRequest,
"AWS::EC2::Subnet": ec2_models.Subnet,
"AWS::EC2::SubnetRouteTableAssociation": ec2_models.SubnetRouteTableAssociation,
"AWS::EC2::Volume": ec2_models.Volume,
"AWS::EC2::VolumeAttachment": ec2_models.VolumeAttachment,
"AWS::EC2::VPC": ec2_models.VPC,
"AWS::EC2::VPCGatewayAttachment": ec2_models.VPCGatewayAttachment,
"AWS::EC2::VPCPeeringConnection": ec2_models.VPCPeeringConnection,
"AWS::ECS::Cluster": ecs_models.Cluster,
"AWS::ECS::TaskDefinition": ecs_models.TaskDefinition,
"AWS::ECS::Service": ecs_models.Service,
"AWS::ElasticLoadBalancing::LoadBalancer": elb_models.FakeLoadBalancer,
"AWS::ElasticLoadBalancingV2::LoadBalancer": elbv2_models.FakeLoadBalancer,
"AWS::ElasticLoadBalancingV2::TargetGroup": elbv2_models.FakeTargetGroup,
"AWS::ElasticLoadBalancingV2::Listener": elbv2_models.FakeListener,
"AWS::Cognito::IdentityPool": cognitoidentity_models.CognitoIdentity,
"AWS::DataPipeline::Pipeline": datapipeline_models.Pipeline,
"AWS::IAM::InstanceProfile": iam_models.InstanceProfile,
"AWS::IAM::Role": iam_models.Role,
"AWS::KMS::Key": kms_models.Key,
"AWS::Logs::LogGroup": cloudwatch_models.LogGroup,
"AWS::RDS::DBInstance": rds_models.Database,
"AWS::RDS::DBSecurityGroup": rds_models.SecurityGroup,
"AWS::RDS::DBSubnetGroup": rds_models.SubnetGroup,
"AWS::RDS::DBParameterGroup": rds2_models.DBParameterGroup,
"AWS::Redshift::Cluster": redshift_models.Cluster,
"AWS::Redshift::ClusterParameterGroup": redshift_models.ParameterGroup,
"AWS::Redshift::ClusterSubnetGroup": redshift_models.SubnetGroup,
"AWS::Route53::HealthCheck": route53_models.HealthCheck,
"AWS::Route53::HostedZone": route53_models.FakeZone,
"AWS::Route53::RecordSet": route53_models.RecordSet,
"AWS::Route53::RecordSetGroup": route53_models.RecordSetGroup,
"AWS::SNS::Topic": sns_models.Topic,
"AWS::S3::Bucket": s3_models.FakeBucket,
"AWS::SQS::Queue": sqs_models.Queue,
}
# http://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-name.html
# List of supported CloudFormation models
MODEL_LIST = CloudFormationModel.__subclasses__()
MODEL_MAP = {model.cloudformation_type(): model for model in MODEL_LIST}
NAME_TYPE_MAP = {
"AWS::CloudWatch::Alarm": "Alarm",
"AWS::DynamoDB::Table": "TableName",
"AWS::ElastiCache::CacheCluster": "ClusterName",
"AWS::ElasticBeanstalk::Application": "ApplicationName",
"AWS::ElasticBeanstalk::Environment": "EnvironmentName",
"AWS::ElasticLoadBalancing::LoadBalancer": "LoadBalancerName",
"AWS::ElasticLoadBalancingV2::TargetGroup": "Name",
"AWS::RDS::DBInstance": "DBInstanceIdentifier",
"AWS::S3::Bucket": "BucketName",
"AWS::SNS::Topic": "TopicName",
"AWS::SQS::Queue": "QueueName",
model.cloudformation_type(): model.cloudformation_name_type()
for model in MODEL_LIST
}
# Just ignore these models types for now
@ -150,7 +101,10 @@ def clean_json(resource_json, resources_map):
map_path = resource_json["Fn::FindInMap"][1:]
result = resources_map[map_name]
for path in map_path:
result = result[clean_json(path, resources_map)]
if "Fn::Transform" in result:
result = resources_map[clean_json(path, resources_map)]
else:
result = result[clean_json(path, resources_map)]
return result
if "Fn::GetAtt" in resource_json:
@ -196,13 +150,13 @@ def clean_json(resource_json, resources_map):
)
else:
fn_sub_value = clean_json(resource_json["Fn::Sub"], resources_map)
to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value)
literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value)
to_sub = re.findall(r'(?=\${)[^!^"]*?}', fn_sub_value)
literals = re.findall(r'(?=\${!)[^"]*?}', fn_sub_value)
for sub in to_sub:
if "." in sub:
cleaned_ref = clean_json(
{
"Fn::GetAtt": re.findall('(?<=\${)[^"]*?(?=})', sub)[
"Fn::GetAtt": re.findall(r'(?<=\${)[^"]*?(?=})', sub)[
0
].split(".")
},
@ -210,7 +164,7 @@ def clean_json(resource_json, resources_map):
)
else:
cleaned_ref = clean_json(
{"Ref": re.findall('(?<=\${)[^"]*?(?=})', sub)[0]},
{"Ref": re.findall(r'(?<=\${)[^"]*?(?=})', sub)[0]},
resources_map,
)
fn_sub_value = fn_sub_value.replace(sub, cleaned_ref)
@ -261,10 +215,14 @@ def resource_class_from_type(resource_type):
if resource_type not in MODEL_MAP:
logger.warning("No Moto CloudFormation support for %s", resource_type)
return None
return MODEL_MAP.get(resource_type)
def resource_name_property_from_type(resource_type):
for model in MODEL_LIST:
if model.cloudformation_type() == resource_type:
return model.cloudformation_name_type()
return NAME_TYPE_MAP.get(resource_type)
@ -283,11 +241,21 @@ def generate_resource_name(resource_type, stack_name, logical_id):
if truncated_name_prefix.endswith("-"):
truncated_name_prefix = truncated_name_prefix[:-1]
return "{0}-{1}".format(truncated_name_prefix, my_random_suffix)
elif resource_type == "AWS::S3::Bucket":
right_hand_part_of_name = "-{0}-{1}".format(logical_id, random_suffix())
max_stack_name_portion_len = 63 - len(right_hand_part_of_name)
return "{0}{1}".format(
stack_name[:max_stack_name_portion_len], right_hand_part_of_name
).lower()
elif resource_type == "AWS::IAM::Policy":
return "{0}-{1}-{2}".format(stack_name[:5], logical_id[:4], random_suffix())
else:
return "{0}-{1}-{2}".format(stack_name, logical_id, random_suffix())
def parse_resource(logical_id, resource_json, resources_map):
def parse_resource(
resource_json, resources_map,
):
resource_type = resource_json["Type"]
resource_class = resource_class_from_type(resource_type)
if not resource_class:
@ -298,22 +266,37 @@ def parse_resource(logical_id, resource_json, resources_map):
)
return None
if "Properties" not in resource_json:
resource_json["Properties"] = {}
resource_json = clean_json(resource_json, resources_map)
return resource_class, resource_json, resource_type
def parse_resource_and_generate_name(
logical_id, resource_json, resources_map,
):
resource_tuple = parse_resource(resource_json, resources_map)
if not resource_tuple:
return None
resource_class, resource_json, resource_type = resource_tuple
generated_resource_name = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
resource_name_property = resource_name_property_from_type(resource_type)
if resource_name_property:
if "Properties" not in resource_json:
resource_json["Properties"] = dict()
if resource_name_property not in resource_json["Properties"]:
resource_json["Properties"][
resource_name_property
] = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
resource_name = resource_json["Properties"][resource_name_property]
if (
"Properties" in resource_json
and resource_name_property in resource_json["Properties"]
):
resource_name = resource_json["Properties"][resource_name_property]
else:
resource_name = generated_resource_name
else:
resource_name = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
resource_name = generated_resource_name
return resource_class, resource_json, resource_name
@ -325,12 +308,14 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n
return None
resource_type = resource_json["Type"]
resource_tuple = parse_resource(logical_id, resource_json, resources_map)
resource_tuple = parse_resource_and_generate_name(
logical_id, resource_json, resources_map
)
if not resource_tuple:
return None
resource_class, resource_json, resource_name = resource_tuple
resource_class, resource_json, resource_physical_name = resource_tuple
resource = resource_class.create_from_cloudformation_json(
resource_name, resource_json, region_name
resource_physical_name, resource_json, region_name
)
resource.type = resource_type
resource.logical_resource_id = logical_id
@ -338,28 +323,34 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n
def parse_and_update_resource(logical_id, resource_json, resources_map, region_name):
resource_class, new_resource_json, new_resource_name = parse_resource(
resource_class, resource_json, new_resource_name = parse_resource_and_generate_name(
logical_id, resource_json, resources_map
)
original_resource = resources_map[logical_id]
new_resource = resource_class.update_from_cloudformation_json(
original_resource=original_resource,
new_resource_name=new_resource_name,
cloudformation_json=new_resource_json,
region_name=region_name,
)
new_resource.type = resource_json["Type"]
new_resource.logical_resource_id = logical_id
return new_resource
if not hasattr(
resource_class.update_from_cloudformation_json, "__isabstractmethod__"
):
new_resource = resource_class.update_from_cloudformation_json(
original_resource=original_resource,
new_resource_name=new_resource_name,
cloudformation_json=resource_json,
region_name=region_name,
)
new_resource.type = resource_json["Type"]
new_resource.logical_resource_id = logical_id
return new_resource
else:
return None
def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name):
resource_class, resource_json, resource_name = parse_resource(
logical_id, resource_json, resources_map
)
resource_class.delete_from_cloudformation_json(
resource_name, resource_json, region_name
)
def parse_and_delete_resource(resource_name, resource_json, resources_map, region_name):
resource_class, resource_json, _ = parse_resource(resource_json, resources_map)
if not hasattr(
resource_class.delete_from_cloudformation_json, "__isabstractmethod__"
):
resource_class.delete_from_cloudformation_json(
resource_name, resource_json, region_name
)
def parse_condition(condition, resources_map, condition_map):
@ -423,7 +414,7 @@ class ResourceMap(collections_abc.Mapping):
cross_stack_resources,
):
self._template = template
self._resource_json_map = template["Resources"]
self._resource_json_map = template["Resources"] if template != {} else {}
self._region_name = region_name
self.input_parameters = parameters
self.tags = copy.deepcopy(tags)
@ -448,6 +439,7 @@ class ResourceMap(collections_abc.Mapping):
return self._parsed_resources[resource_logical_id]
else:
resource_json = self._resource_json_map.get(resource_logical_id)
if not resource_json:
raise KeyError(resource_logical_id)
new_resource = parse_and_create_resource(
@ -463,6 +455,34 @@ class ResourceMap(collections_abc.Mapping):
def __len__(self):
return len(self._resource_json_map)
def __get_resources_in_dependency_order(self):
resource_map = copy.deepcopy(self._resource_json_map)
resources_in_dependency_order = []
def recursively_get_dependencies(resource):
resource_info = resource_map[resource]
if "DependsOn" not in resource_info:
resources_in_dependency_order.append(resource)
del resource_map[resource]
return
dependencies = resource_info["DependsOn"]
if isinstance(dependencies, str): # Dependencies may be a string or list
dependencies = [dependencies]
for dependency in dependencies:
if dependency in resource_map:
recursively_get_dependencies(dependency)
resources_in_dependency_order.append(resource)
del resource_map[resource]
while resource_map:
recursively_get_dependencies(list(resource_map.keys())[0])
return resources_in_dependency_order
@property
def resources(self):
return self._resource_json_map.keys()
@ -470,6 +490,17 @@ class ResourceMap(collections_abc.Mapping):
def load_mapping(self):
self._parsed_resources.update(self._template.get("Mappings", {}))
def transform_mapping(self):
for k, v in self._template.get("Mappings", {}).items():
if "Fn::Transform" in v:
name = v["Fn::Transform"]["Name"]
params = v["Fn::Transform"]["Parameters"]
if name == "AWS::Include":
location = params["Location"]
bucket_name, name = bucket_and_name_from_url(location)
key = s3_backend.get_object(bucket_name, name)
self._parsed_resources.update(json.loads(key.value))
def load_parameters(self):
parameter_slots = self._template.get("Parameters", {})
for parameter_name, parameter in parameter_slots.items():
@ -486,6 +517,23 @@ class ResourceMap(collections_abc.Mapping):
if value_type == "CommaDelimitedList" or value_type.startswith("List"):
value = value.split(",")
def _parse_number_parameter(num_string):
"""CloudFormation NUMBER types can be an int or float.
Try int first and then fall back to float if that fails
"""
try:
return int(num_string)
except ValueError:
return float(num_string)
if value_type == "List<Number>":
# The if statement directly above already converted
# to a list. Now we convert each element to a number
value = [_parse_number_parameter(v) for v in value]
if value_type == "Number":
value = _parse_number_parameter(value)
if parameter_slot.get("NoEcho"):
self.no_echo_parameter_keys.append(key)
@ -513,20 +561,25 @@ class ResourceMap(collections_abc.Mapping):
for condition_name in self.lazy_condition_map:
self.lazy_condition_map[condition_name]
def create(self):
def load(self):
self.load_mapping()
self.transform_mapping()
self.load_parameters()
self.load_conditions()
def create(self, template):
# Since this is a lazy map, to create every object we just need to
# iterate through self.
# Assumes that self.load() has been called before
self._template = template
self._resource_json_map = template["Resources"]
self.tags.update(
{
"aws:cloudformation:stack-name": self.get("AWS::StackName"),
"aws:cloudformation:stack-id": self.get("AWS::StackId"),
}
)
for resource in self.resources:
for resource in self.__get_resources_in_dependency_order():
if isinstance(self[resource], ec2_models.TaggedEC2Resource):
self.tags["aws:cloudformation:logical-id"] = resource
ec2_models.ec2_backends[self._region_name].create_tags(
@ -588,28 +641,36 @@ class ResourceMap(collections_abc.Mapping):
)
self._parsed_resources[resource_name] = new_resource
for resource_name, resource in resources_by_action["Remove"].items():
resource_json = old_template[resource_name]
for logical_name, _ in resources_by_action["Remove"].items():
resource_json = old_template[logical_name]
resource = self._parsed_resources[logical_name]
# ToDo: Standardize this.
if hasattr(resource, "physical_resource_id"):
resource_name = self._parsed_resources[
logical_name
].physical_resource_id
else:
resource_name = None
parse_and_delete_resource(
resource_name, resource_json, self, self._region_name
)
self._parsed_resources.pop(resource_name)
self._parsed_resources.pop(logical_name)
tries = 1
while resources_by_action["Modify"] and tries < 5:
for resource_name, resource in resources_by_action["Modify"].copy().items():
resource_json = new_template[resource_name]
for logical_name, _ in resources_by_action["Modify"].copy().items():
resource_json = new_template[logical_name]
try:
changed_resource = parse_and_update_resource(
resource_name, resource_json, self, self._region_name
logical_name, resource_json, self, self._region_name
)
except Exception as e:
# skip over dependency violations, and try again in a
# second pass
last_exception = e
else:
self._parsed_resources[resource_name] = changed_resource
del resources_by_action["Modify"][resource_name]
self._parsed_resources[logical_name] = changed_resource
del resources_by_action["Modify"][logical_name]
tries += 1
if tries == 5:
raise last_exception
@ -623,6 +684,21 @@ class ResourceMap(collections_abc.Mapping):
try:
if parsed_resource and hasattr(parsed_resource, "delete"):
parsed_resource.delete(self._region_name)
else:
if hasattr(parsed_resource, "physical_resource_id"):
resource_name = parsed_resource.physical_resource_id
else:
resource_name = None
resource_json = self._resource_json_map[
parsed_resource.logical_resource_id
]
parse_and_delete_resource(
resource_name, resource_json, self, self._region_name,
)
self._parsed_resources.pop(parsed_resource.logical_resource_id)
except Exception as e:
# skip over dependency violations, and try again in a
# second pass

View File

@ -10,6 +10,31 @@ from moto.s3 import s3_backend
from moto.core import ACCOUNT_ID
from .models import cloudformation_backends
from .exceptions import ValidationError
from .utils import yaml_tag_constructor
def get_template_summary_response_from_template(template_body):
def get_resource_types(template_dict):
resources = {}
for key, value in template_dict.items():
if key == "Resources":
resources = value
resource_types = []
for key, value in resources.items():
resource_types.append(value["Type"])
return resource_types
yaml.add_multi_constructor("", yaml_tag_constructor)
try:
template_dict = yaml.load(template_body, Loader=yaml.Loader)
except (yaml.parser.ParserError, yaml.scanner.ScannerError):
template_dict = json.loads(template_body)
resources_types = get_resource_types(template_dict)
template_dict["resourceTypes"] = resources_types
return template_dict
class CloudFormationResponse(BaseResponse):
@ -36,7 +61,7 @@ class CloudFormationResponse(BaseResponse):
bucket_name = template_url_parts.netloc.split(".")[0]
key_name = template_url_parts.path.lstrip("/")
key = s3_backend.get_key(bucket_name, key_name)
key = s3_backend.get_object(bucket_name, key_name)
return key.value.decode("utf-8")
def create_stack(self):
@ -50,6 +75,12 @@ class CloudFormationResponse(BaseResponse):
for item in self._get_list_prefix("Tags.member")
)
if self.stack_name_exists(new_stack_name=stack_name):
template = self.response_template(
CREATE_STACK_NAME_EXISTS_RESPONSE_TEMPLATE
)
return 400, {"status": 400}, template.render(name=stack_name)
# Hack dict-comprehension
parameters = dict(
[
@ -82,6 +113,12 @@ class CloudFormationResponse(BaseResponse):
template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE)
return template.render(stack=stack)
def stack_name_exists(self, new_stack_name):
for stack in self.cloudformation_backend.stacks.values():
if stack.name == new_stack_name:
return True
return False
@amzn_request_id
def create_change_set(self):
stack_name = self._get_param("StackName")
@ -221,7 +258,8 @@ class CloudFormationResponse(BaseResponse):
return template.render(change_sets=change_sets)
def list_stacks(self):
stacks = self.cloudformation_backend.list_stacks()
status_filter = self._get_multi_param("StackStatusFilter.member")
stacks = self.cloudformation_backend.list_stacks(status_filter)
template = self.response_template(LIST_STACKS_RESPONSE)
return template.render(stacks=stacks)
@ -256,6 +294,20 @@ class CloudFormationResponse(BaseResponse):
template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE)
return template.render(stack=stack)
def get_template_summary(self):
stack_name = self._get_param("StackName")
template_url = self._get_param("TemplateURL")
stack_body = self._get_param("TemplateBody")
if stack_name:
stack_body = self.cloudformation_backend.get_stack(stack_name).template
elif template_url:
stack_body = self._get_stack_from_s3_url(template_url)
template_summary = get_template_summary_response_from_template(stack_body)
template = self.response_template(GET_TEMPLATE_SUMMARY_TEMPLATE)
return template.render(template_summary=template_summary)
def update_stack(self):
stack_name = self._get_param("StackName")
role_arn = self._get_param("RoleARN")
@ -339,19 +391,22 @@ class CloudFormationResponse(BaseResponse):
return template.render(exports=exports, next_token=next_token)
def validate_template(self):
cfn_lint = self.cloudformation_backend.validate_template(
self._get_param("TemplateBody")
)
template_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
if template_url:
template_body = self._get_stack_from_s3_url(template_url)
cfn_lint = self.cloudformation_backend.validate_template(template_body)
if cfn_lint:
raise ValidationError(cfn_lint[0].message)
description = ""
try:
description = json.loads(self._get_param("TemplateBody"))["Description"]
description = json.loads(template_body)["Description"]
except (ValueError, KeyError):
pass
try:
description = yaml.load(self._get_param("TemplateBody"))["Description"]
except (yaml.ParserError, KeyError):
description = yaml.load(template_body, Loader=yaml.Loader)["Description"]
except (yaml.parser.ParserError, yaml.scanner.ScannerError, KeyError):
pass
template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE)
return template.render(description=description)
@ -564,6 +619,15 @@ CREATE_STACK_RESPONSE_TEMPLATE = """<CreateStackResponse>
</CreateStackResponse>
"""
CREATE_STACK_NAME_EXISTS_RESPONSE_TEMPLATE = """<ErrorResponse xmlns="http://cloudformation.amazonaws.com/doc/2010-05-15/">
<Error>
<Type>Sender</Type>
<Code>AlreadyExistsException</Code>
<Message>Stack [{{ name }}] already exists</Message>
</Error>
<RequestId>950ff8d7-812a-44b3-bb0c-9b271b954104</RequestId>
</ErrorResponse>"""
UPDATE_STACK_RESPONSE_TEMPLATE = """<UpdateStackResponse xmlns="http://cloudformation.amazonaws.com/doc/2010-05-15/">
<UpdateStackResult>
<StackId>{{ stack.stack_id }}</StackId>
@ -609,7 +673,7 @@ DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE = """<DescribeChangeSetResponse>
</member>
{% endfor %}
</Parameters>
<CreationTime>2011-05-23T15:47:44Z</CreationTime>
<CreationTime>{{ change_set.creation_time_iso_8601 }}</CreationTime>
<ExecutionStatus>{{ change_set.execution_status }}</ExecutionStatus>
<Status>{{ change_set.status }}</Status>
<StatusReason>{{ change_set.status_reason }}</StatusReason>
@ -662,7 +726,11 @@ DESCRIBE_STACKS_TEMPLATE = """<DescribeStacksResponse>
<member>
<StackName>{{ stack.name }}</StackName>
<StackId>{{ stack.stack_id }}</StackId>
<CreationTime>2010-07-27T22:28:28Z</CreationTime>
{% if stack.change_set_id %}
<ChangeSetId>{{ stack.change_set_id }}</ChangeSetId>
{% endif %}
<Description>{{ stack.description }}</Description>
<CreationTime>{{ stack.creation_time_iso_8601 }}</CreationTime>
<StackStatus>{{ stack.status }}</StackStatus>
{% if stack.notification_arns %}
<NotificationARNs>
@ -714,7 +782,6 @@ DESCRIBE_STACKS_TEMPLATE = """<DescribeStacksResponse>
</DescribeStacksResult>
</DescribeStacksResponse>"""
DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE = """<DescribeStackResourceResponse>
<DescribeStackResourceResult>
<StackResourceDetail>
@ -729,7 +796,6 @@ DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE = """<DescribeStackResourceResponse>
</DescribeStackResourceResult>
</DescribeStackResourceResponse>"""
DESCRIBE_STACK_RESOURCES_RESPONSE = """<DescribeStackResourcesResponse>
<DescribeStackResourcesResult>
<StackResources>
@ -748,7 +814,6 @@ DESCRIBE_STACK_RESOURCES_RESPONSE = """<DescribeStackResourcesResponse>
</DescribeStackResourcesResult>
</DescribeStackResourcesResponse>"""
DESCRIBE_STACK_EVENTS_RESPONSE = """<DescribeStackEventsResponse xmlns="http://cloudformation.amazonaws.com/doc/2010-05-15/">
<DescribeStackEventsResult>
<StackEvents>
@ -773,7 +838,6 @@ DESCRIBE_STACK_EVENTS_RESPONSE = """<DescribeStackEventsResponse xmlns="http://c
</ResponseMetadata>
</DescribeStackEventsResponse>"""
LIST_CHANGE_SETS_RESPONSE = """<ListChangeSetsResponse>
<ListChangeSetsResult>
<Summaries>
@ -794,7 +858,6 @@ LIST_CHANGE_SETS_RESPONSE = """<ListChangeSetsResponse>
</ListChangeSetsResult>
</ListChangeSetsResponse>"""
LIST_STACKS_RESPONSE = """<ListStacksResponse>
<ListStacksResult>
<StackSummaries>
@ -803,7 +866,7 @@ LIST_STACKS_RESPONSE = """<ListStacksResponse>
<StackId>{{ stack.stack_id }}</StackId>
<StackStatus>{{ stack.status }}</StackStatus>
<StackName>{{ stack.name }}</StackName>
<CreationTime>2011-05-23T15:47:44Z</CreationTime>
<CreationTime>{{ stack.creation_time_iso_8601 }}</CreationTime>
<TemplateDescription>{{ stack.description }}</TemplateDescription>
</member>
{% endfor %}
@ -811,7 +874,6 @@ LIST_STACKS_RESPONSE = """<ListStacksResponse>
</ListStacksResult>
</ListStacksResponse>"""
LIST_STACKS_RESOURCES_RESPONSE = """<ListStackResourcesResponse>
<ListStackResourcesResult>
<StackResourceSummaries>
@ -831,7 +893,6 @@ LIST_STACKS_RESOURCES_RESPONSE = """<ListStackResourcesResponse>
</ResponseMetadata>
</ListStackResourcesResponse>"""
GET_TEMPLATE_RESPONSE_TEMPLATE = """<GetTemplateResponse>
<GetTemplateResult>
<TemplateBody>{{ stack.template }}</TemplateBody>
@ -841,7 +902,6 @@ GET_TEMPLATE_RESPONSE_TEMPLATE = """<GetTemplateResponse>
</ResponseMetadata>
</GetTemplateResponse>"""
DELETE_STACK_RESPONSE_TEMPLATE = """<DeleteStackResponse>
<ResponseMetadata>
<RequestId>5ccc7dcd-744c-11e5-be70-example</RequestId>
@ -849,7 +909,6 @@ DELETE_STACK_RESPONSE_TEMPLATE = """<DeleteStackResponse>
</DeleteStackResponse>
"""
LIST_EXPORTS_RESPONSE = """<ListExportsResponse xmlns="http://cloudformation.amazonaws.com/doc/2010-05-15/">
<ListExportsResult>
<Exports>
@ -1110,3 +1169,19 @@ LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = (
</ListStackSetOperationResultsResponse>
"""
)
GET_TEMPLATE_SUMMARY_TEMPLATE = """<GetTemplateSummaryResponse xmlns="http://cloudformation.amazonaws.com/doc/2010-05-15/">
<GetTemplateSummaryResult>
<Description>{{ template_summary.Description }}</Description>
{% for resource in template_summary.resourceTypes %}
<ResourceTypes>
<ResourceType>{{ resource }}</ResourceType>
</ResourceTypes>
{% endfor %}
<Version>{{ template_summary.AWSTemplateFormatVersion }}</Version>
</GetTemplateSummaryResult>
<ResponseMetadata>
<RequestId>b9b4b068-3a41-11e5-94eb-example</RequestId>
</ResponseMetadata>
</GetTemplateSummaryResponse>
"""

View File

@ -6,7 +6,6 @@ import yaml
import os
import string
from cfnlint import decode, core
from moto.core import ACCOUNT_ID
@ -42,8 +41,7 @@ def random_suffix():
def yaml_tag_constructor(loader, tag, node):
"""convert shorthand intrinsic function to full name
"""
"""convert shorthand intrinsic function to full name"""
def _f(loader, tag, node):
if tag == "!GetAtt":
@ -62,6 +60,8 @@ def yaml_tag_constructor(loader, tag, node):
def validate_template_cfn_lint(template):
# Importing cfnlint adds a significant overhead, so we keep it local
from cfnlint import decode, core
# Save the template to a temporary file -- cfn-lint requires a file
filename = "file.tmp"
@ -70,7 +70,12 @@ def validate_template_cfn_lint(template):
abs_filename = os.path.abspath(filename)
# decode handles both yaml and json
template, matches = decode.decode(abs_filename, False)
try:
template, matches = decode.decode(abs_filename, False)
except TypeError:
# As of cfn-lint 0.39.0, the second argument (ignore_bad_template) was dropped
# https://github.com/aws-cloudformation/cfn-python-lint/pull/1580
template, matches = decode.decode(abs_filename)
# Set cfn-lint to info
core.configure_logging(None)

View File

@ -2,13 +2,15 @@ import json
from boto3 import Session
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.exceptions import RESTError
from moto.logs import logs_backends
from datetime import datetime, timedelta
from dateutil.tz import tzutc
from uuid import uuid4
from .utils import make_arn_for_dashboard
from .utils import make_arn_for_dashboard, make_arn_for_alarm
from dateutil import parser
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
@ -20,6 +22,41 @@ class Dimension(object):
self.name = name
self.value = value
def __eq__(self, item):
if isinstance(item, Dimension):
return self.name == item.name and self.value == item.value
return False
def __ne__(self, item): # Only needed on Py2; Py3 defines it implicitly
return self != item
class Metric(object):
def __init__(self, metric_name, namespace, dimensions):
self.metric_name = metric_name
self.namespace = namespace
self.dimensions = dimensions
class MetricStat(object):
def __init__(self, metric, period, stat, unit):
self.metric = metric
self.period = period
self.stat = stat
self.unit = unit
class MetricDataQuery(object):
def __init__(
self, id, label, period, return_data, expression=None, metric_stat=None
):
self.id = id
self.label = label
self.period = period
self.return_data = return_data
self.expression = expression
self.metric_stat = metric_stat
def daterange(start, stop, step=timedelta(days=1), inclusive=False):
"""
@ -55,8 +92,10 @@ class FakeAlarm(BaseModel):
name,
namespace,
metric_name,
metric_data_queries,
comparison_operator,
evaluation_periods,
datapoints_to_alarm,
period,
threshold,
statistic,
@ -66,12 +105,17 @@ class FakeAlarm(BaseModel):
ok_actions,
insufficient_data_actions,
unit,
actions_enabled,
region="us-east-1",
):
self.name = name
self.alarm_arn = make_arn_for_alarm(region, DEFAULT_ACCOUNT_ID, name)
self.namespace = namespace
self.metric_name = metric_name
self.metric_data_queries = metric_data_queries
self.comparison_operator = comparison_operator
self.evaluation_periods = evaluation_periods
self.datapoints_to_alarm = datapoints_to_alarm
self.period = period
self.threshold = threshold
self.statistic = statistic
@ -79,6 +123,7 @@ class FakeAlarm(BaseModel):
self.dimensions = [
Dimension(dimension["name"], dimension["value"]) for dimension in dimensions
]
self.actions_enabled = actions_enabled
self.alarm_actions = alarm_actions
self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions
@ -110,6 +155,18 @@ class FakeAlarm(BaseModel):
self.state_updated_timestamp = datetime.utcnow()
def are_dimensions_same(metric_dimensions, dimensions):
for dimension in metric_dimensions:
for new_dimension in dimensions:
if (
dimension.name != new_dimension.name
or dimension.value != new_dimension.value
):
return False
return True
class MetricDatum(BaseModel):
def __init__(self, namespace, name, value, dimensions, timestamp):
self.namespace = namespace
@ -120,6 +177,23 @@ class MetricDatum(BaseModel):
Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions
]
def filter(self, namespace, name, dimensions, already_present_metrics):
if namespace and namespace != self.namespace:
return False
if name and name != self.name:
return False
for metric in already_present_metrics:
if self.dimensions and are_dimensions_same(
metric.dimensions, self.dimensions
):
return False
if dimensions and any(
Dimension(d["Name"], d["Value"]) not in self.dimensions for d in dimensions
):
return False
return True
class Dashboard(BaseModel):
def __init__(self, name, body):
@ -146,7 +220,7 @@ class Dashboard(BaseModel):
class Statistics:
def __init__(self, stats, dt):
self.timestamp = iso_8601_datetime_with_milliseconds(dt)
self.timestamp = iso_8601_datetime_without_milliseconds(dt)
self.values = []
self.stats = stats
@ -198,13 +272,24 @@ class CloudWatchBackend(BaseBackend):
self.metric_data = []
self.paged_metric_data = {}
@property
# Retrieve a list of all OOTB metrics that are provided by metrics providers
# Computed on the fly
def aws_metric_data(self):
md = []
for name, service in metric_providers.items():
md.extend(service.get_cloudwatch_metrics())
return md
def put_metric_alarm(
self,
name,
namespace,
metric_name,
metric_data_queries,
comparison_operator,
evaluation_periods,
datapoints_to_alarm,
period,
threshold,
statistic,
@ -214,13 +299,17 @@ class CloudWatchBackend(BaseBackend):
ok_actions,
insufficient_data_actions,
unit,
actions_enabled,
region="us-east-1",
):
alarm = FakeAlarm(
name,
namespace,
metric_name,
metric_data_queries,
comparison_operator,
evaluation_periods,
datapoints_to_alarm,
period,
threshold,
statistic,
@ -230,7 +319,10 @@ class CloudWatchBackend(BaseBackend):
ok_actions,
insufficient_data_actions,
unit,
actions_enabled,
region,
)
self.alarms[name] = alarm
return alarm
@ -270,6 +362,13 @@ class CloudWatchBackend(BaseBackend):
)
def delete_alarms(self, alarm_names):
for alarm_name in alarm_names:
if alarm_name not in self.alarms:
raise RESTError(
"ResourceNotFound",
"Alarm {0} not found".format(alarm_name),
status=404,
)
for alarm_name in alarm_names:
self.alarms.pop(alarm_name, None)
@ -278,8 +377,7 @@ class CloudWatchBackend(BaseBackend):
# Preserve "datetime" for get_metric_statistics comparisons
timestamp = metric_member.get("Timestamp")
if timestamp is not None and type(timestamp) != datetime:
timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
timestamp = timestamp.replace(tzinfo=tzutc())
timestamp = parser.parse(timestamp)
self.metric_data.append(
MetricDatum(
namespace,
@ -290,6 +388,43 @@ class CloudWatchBackend(BaseBackend):
)
)
def get_metric_data(self, queries, start_time, end_time):
period_data = [
md for md in self.metric_data if start_time <= md.timestamp <= end_time
]
results = []
for query in queries:
query_ns = query["metric_stat._metric._namespace"]
query_name = query["metric_stat._metric._metric_name"]
query_data = [
md
for md in period_data
if md.namespace == query_ns and md.name == query_name
]
metric_values = [m.value for m in query_data]
result_vals = []
stat = query["metric_stat._stat"]
if len(metric_values) > 0:
if stat == "Average":
result_vals.append(sum(metric_values) / len(metric_values))
elif stat == "Minimum":
result_vals.append(min(metric_values))
elif stat == "Maximum":
result_vals.append(max(metric_values))
elif stat == "Sum":
result_vals.append(sum(metric_values))
label = query["metric_stat._metric._metric_name"] + " " + stat
results.append(
{
"id": query["id"],
"label": label,
"vals": result_vals,
"timestamps": [datetime.now() for _ in result_vals],
}
)
return results
def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats
):
@ -329,7 +464,7 @@ class CloudWatchBackend(BaseBackend):
return data
def get_all_metrics(self):
return self.metric_data
return self.metric_data + self.aws_metric_data
def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body)
@ -381,7 +516,7 @@ class CloudWatchBackend(BaseBackend):
self.alarms[alarm_name].update_state(reason, reason_data, state_value)
def list_metrics(self, next_token, namespace, metric_name):
def list_metrics(self, next_token, namespace, metric_name, dimensions):
if next_token:
if next_token not in self.paged_metric_data:
raise RESTError(
@ -392,16 +527,21 @@ class CloudWatchBackend(BaseBackend):
del self.paged_metric_data[next_token] # Cant reuse same token twice
return self._get_paginated(metrics)
else:
metrics = self.get_filtered_metrics(metric_name, namespace)
metrics = self.get_filtered_metrics(metric_name, namespace, dimensions)
return self._get_paginated(metrics)
def get_filtered_metrics(self, metric_name, namespace):
def get_filtered_metrics(self, metric_name, namespace, dimensions):
metrics = self.get_all_metrics()
if namespace:
metrics = [md for md in metrics if md.namespace == namespace]
if metric_name:
metrics = [md for md in metrics if md.name == metric_name]
return metrics
new_metrics = []
for md in metrics:
if md.filter(
namespace=namespace,
name=metric_name,
dimensions=dimensions,
already_present_metrics=new_metrics,
):
new_metrics.append(md)
return new_metrics
def _get_paginated(self, metrics):
if len(metrics) > 500:
@ -412,24 +552,31 @@ class CloudWatchBackend(BaseBackend):
return None, metrics
class LogGroup(BaseModel):
class LogGroup(CloudFormationModel):
def __init__(self, spec):
# required
self.name = spec["LogGroupName"]
# optional
self.tags = spec.get("Tags", [])
@staticmethod
def cloudformation_name_type():
return "LogGroupName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-logs-loggroup.html
return "AWS::Logs::LogGroup"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
spec = {"LogGroupName": properties["LogGroupName"]}
optional_properties = "Tags".split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]
return LogGroup(spec)
tags = properties.get("Tags", {})
return logs_backends[region_name].create_log_group(
resource_name, tags, **properties
)
cloudwatch_backends = {}
@ -441,3 +588,8 @@ for region in Session().get_available_regions(
cloudwatch_backends[region] = CloudWatchBackend()
for region in Session().get_available_regions("cloudwatch", partition_name="aws-cn"):
cloudwatch_backends[region] = CloudWatchBackend()
# List of services that provide OOTB CW metrics
# See the S3Backend constructor for an example
# TODO: We might have to separate this out per region for non-global services
metric_providers = {}

View File

@ -1,7 +1,7 @@
import json
from moto.core.utils import amzn_request_id
from moto.core.responses import BaseResponse
from .models import cloudwatch_backends
from .models import cloudwatch_backends, MetricDataQuery, MetricStat, Metric, Dimension
from dateutil.parser import parse as dtparse
@ -19,8 +19,37 @@ class CloudWatchResponse(BaseResponse):
name = self._get_param("AlarmName")
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
metrics = self._get_multi_param("Metrics.member")
metric_data_queries = None
if metrics:
metric_data_queries = [
MetricDataQuery(
id=metric.get("Id"),
label=metric.get("Label"),
period=metric.get("Period"),
return_data=metric.get("ReturnData"),
expression=metric.get("Expression"),
metric_stat=MetricStat(
metric=Metric(
metric_name=metric.get("MetricStat.Metric.MetricName"),
namespace=metric.get("MetricStat.Metric.Namespace"),
dimensions=[
Dimension(name=dim["Name"], value=dim["Value"])
for dim in metric["MetricStat.Metric.Dimensions.member"]
],
),
period=metric.get("MetricStat.Period"),
stat=metric.get("MetricStat.Stat"),
unit=metric.get("MetricStat.Unit"),
)
if "MetricStat.Metric.MetricName" in metric
else None,
)
for metric in metrics
]
comparison_operator = self._get_param("ComparisonOperator")
evaluation_periods = self._get_param("EvaluationPeriods")
datapoints_to_alarm = self._get_param("DatapointsToAlarm")
period = self._get_param("Period")
threshold = self._get_param("Threshold")
statistic = self._get_param("Statistic")
@ -28,6 +57,7 @@ class CloudWatchResponse(BaseResponse):
dimensions = self._get_list_prefix("Dimensions.member")
alarm_actions = self._get_multi_param("AlarmActions.member")
ok_actions = self._get_multi_param("OKActions.member")
actions_enabled = self._get_param("ActionsEnabled")
insufficient_data_actions = self._get_multi_param(
"InsufficientDataActions.member"
)
@ -36,8 +66,10 @@ class CloudWatchResponse(BaseResponse):
name,
namespace,
metric_name,
metric_data_queries,
comparison_operator,
evaluation_periods,
datapoints_to_alarm,
period,
threshold,
statistic,
@ -47,6 +79,8 @@ class CloudWatchResponse(BaseResponse):
ok_actions,
insufficient_data_actions,
unit,
actions_enabled,
self.region,
)
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm)
@ -90,6 +124,18 @@ class CloudWatchResponse(BaseResponse):
template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
return template.render()
@amzn_request_id
def get_metric_data(self):
start = dtparse(self._get_param("StartTime"))
end = dtparse(self._get_param("EndTime"))
queries = self._get_list_prefix("MetricDataQueries.member")
results = self.cloudwatch_backend.get_metric_data(
start_time=start, end_time=end, queries=queries
)
template = self.response_template(GET_METRIC_DATA_TEMPLATE)
return template.render(results=results)
@amzn_request_id
def get_metric_statistics(self):
namespace = self._get_param("Namespace")
@ -122,9 +168,10 @@ class CloudWatchResponse(BaseResponse):
def list_metrics(self):
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
dimensions = self._get_multi_param("Dimensions.member")
next_token = self._get_param("NextToken")
next_token, metrics = self.cloudwatch_backend.list_metrics(
next_token, namespace, metric_name
next_token, namespace, metric_name, dimensions
)
template = self.response_template(LIST_METRICS_TEMPLATE)
return template.render(metrics=metrics, next_token=next_token)
@ -146,9 +193,23 @@ class CloudWatchResponse(BaseResponse):
def describe_alarm_history(self):
raise NotImplementedError()
@staticmethod
def filter_alarms(alarms, metric_name, namespace):
metric_filtered_alarms = []
for alarm in alarms:
if alarm.metric_name == metric_name and alarm.namespace == namespace:
metric_filtered_alarms.append(alarm)
return metric_filtered_alarms
@amzn_request_id
def describe_alarms_for_metric(self):
raise NotImplementedError()
alarms = self.cloudwatch_backend.get_all_alarms()
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
filtered_alarms = self.filter_alarms(alarms, metric_name, namespace)
template = self.response_template(DESCRIBE_METRIC_ALARMS_TEMPLATE)
return template.render(alarms=filtered_alarms)
@amzn_request_id
def disable_alarm_actions(self):
@ -227,7 +288,115 @@ DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.a
<member>{{ action }}</member>
{% endfor %}
</AlarmActions>
<AlarmArn>{{ alarm.arn }}</AlarmArn>
<AlarmArn>{{ alarm.alarm_arn }}</AlarmArn>
<AlarmConfigurationUpdatedTimestamp>{{ alarm.configuration_updated_timestamp }}</AlarmConfigurationUpdatedTimestamp>
<AlarmDescription>{{ alarm.description }}</AlarmDescription>
<AlarmName>{{ alarm.name }}</AlarmName>
<ComparisonOperator>{{ alarm.comparison_operator }}</ComparisonOperator>
{% if alarm.dimensions is not none %}
<Dimensions>
{% for dimension in alarm.dimensions %}
<member>
<Name>{{ dimension.name }}</Name>
<Value>{{ dimension.value }}</Value>
</member>
{% endfor %}
</Dimensions>
{% endif %}
<EvaluationPeriods>{{ alarm.evaluation_periods }}</EvaluationPeriods>
{% if alarm.datapoints_to_alarm is not none %}
<DatapointsToAlarm>{{ alarm.datapoints_to_alarm }}</DatapointsToAlarm>
{% endif %}
<InsufficientDataActions>
{% for action in alarm.insufficient_data_actions %}
<member>{{ action }}</member>
{% endfor %}
</InsufficientDataActions>
{% if alarm.metric_name is not none %}
<MetricName>{{ alarm.metric_name }}</MetricName>
{% endif %}
{% if alarm.metric_data_queries is not none %}
<Metrics>
{% for metric in alarm.metric_data_queries %}
<member>
<Id>{{ metric.id }}</Id>
{% if metric.label is not none %}
<Label>{{ metric.label }}</Label>
{% endif %}
{% if metric.expression is not none %}
<Expression>{{ metric.expression }}</Expression>
{% endif %}
{% if metric.metric_stat is not none %}
<MetricStat>
<Metric>
<Namespace>{{ metric.metric_stat.metric.namespace }}</Namespace>
<MetricName>{{ metric.metric_stat.metric.metric_name }}</MetricName>
<Dimensions>
{% for dim in metric.metric_stat.metric.dimensions %}
<member>
<Name>{{ dim.name }}</Name>
<Value>{{ dim.value }}</Value>
</member>
{% endfor %}
</Dimensions>
</Metric>
{% if metric.metric_stat.period is not none %}
<Period>{{ metric.metric_stat.period }}</Period>
{% endif %}
<Stat>{{ metric.metric_stat.stat }}</Stat>
{% if metric.metric_stat.unit is not none %}
<Unit>{{ metric.metric_stat.unit }}</Unit>
{% endif %}
</MetricStat>
{% endif %}
{% if metric.period is not none %}
<Period>{{ metric.period }}</Period>
{% endif %}
<ReturnData>{{ metric.return_data }}</ReturnData>
</member>
{% endfor %}
</Metrics>
{% endif %}
{% if alarm.namespace is not none %}
<Namespace>{{ alarm.namespace }}</Namespace>
{% endif %}
<OKActions>
{% for action in alarm.ok_actions %}
<member>{{ action }}</member>
{% endfor %}
</OKActions>
{% if alarm.period is not none %}
<Period>{{ alarm.period }}</Period>
{% endif %}
<StateReason>{{ alarm.state_reason }}</StateReason>
<StateReasonData>{{ alarm.state_reason_data }}</StateReasonData>
<StateUpdatedTimestamp>{{ alarm.state_updated_timestamp }}</StateUpdatedTimestamp>
<StateValue>{{ alarm.state_value }}</StateValue>
{% if alarm.statistic is not none %}
<Statistic>{{ alarm.statistic }}</Statistic>
{% endif %}
<Threshold>{{ alarm.threshold }}</Threshold>
{% if alarm.unit is not none %}
<Unit>{{ alarm.unit }}</Unit>
{% endif %}
</member>
{% endfor %}
</MetricAlarms>
</DescribeAlarmsResult>
</DescribeAlarmsResponse>"""
DESCRIBE_METRIC_ALARMS_TEMPLATE = """<DescribeAlarmsForMetricResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<DescribeAlarmsForMetricResult>
<MetricAlarms>
{% for alarm in alarms %}
<member>
<ActionsEnabled>{{ alarm.actions_enabled }}</ActionsEnabled>
<AlarmActions>
{% for action in alarm.alarm_actions %}
<member>{{ action }}</member>
{% endfor %}
</AlarmActions>
<AlarmArn>{{ alarm.alarm_arn }}</AlarmArn>
<AlarmConfigurationUpdatedTimestamp>{{ alarm.configuration_updated_timestamp }}</AlarmConfigurationUpdatedTimestamp>
<AlarmDescription>{{ alarm.description }}</AlarmDescription>
<AlarmName>{{ alarm.name }}</AlarmName>
@ -264,8 +433,8 @@ DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.a
</member>
{% endfor %}
</MetricAlarms>
</DescribeAlarmsResult>
</DescribeAlarmsResponse>"""
</DescribeAlarmsForMetricResult>
</DescribeAlarmsForMetricResponse>"""
DELETE_METRIC_ALARMS_TEMPLATE = """<DeleteMetricAlarmResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
@ -283,6 +452,35 @@ PUT_METRIC_DATA_TEMPLATE = """<PutMetricDataResponse xmlns="http://monitoring.am
</ResponseMetadata>
</PutMetricDataResponse>"""
GET_METRIC_DATA_TEMPLATE = """<GetMetricDataResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
<RequestId>
{{ request_id }}
</RequestId>
</ResponseMetadata>
<GetMetricDataResult>
<MetricDataResults>
{% for result in results %}
<member>
<Id>{{ result.id }}</Id>
<Label>{{ result.label }}</Label>
<StatusCode>Complete</StatusCode>
<Timestamps>
{% for val in result.timestamps %}
<member>{{ val }}</member>
{% endfor %}
</Timestamps>
<Values>
{% for val in result.vals %}
<member>{{ val }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</MetricDataResults>
</GetMetricDataResult>
</GetMetricDataResponse>"""
GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
<RequestId>

View File

@ -3,3 +3,7 @@ from __future__ import unicode_literals
def make_arn_for_dashboard(account_id, name):
return "arn:aws:cloudwatch::{0}dashboard/{1}".format(account_id, name)
def make_arn_for_alarm(region, account_id, alarm_name):
return "arn:aws:cloudwatch:{0}:{1}:alarm:{2}".format(region, account_id, alarm_name)

View File

@ -2,7 +2,7 @@ from boto3 import Session
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from datetime import datetime
from moto.iam.models import ACCOUNT_ID
from moto.core import ACCOUNT_ID
from .exceptions import RepositoryDoesNotExistException, RepositoryNameExistsException
import uuid

View File

@ -15,9 +15,7 @@ from moto.codepipeline.exceptions import (
InvalidTagsException,
TooManyTagsException,
)
from moto.core import BaseBackend, BaseModel
from moto.iam.models import ACCOUNT_ID
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
class CodePipeline(BaseModel):

View File

@ -1,5 +1,5 @@
from moto.core.utils import get_random_hex
from uuid import uuid4
def get_random_identity_id(region):
return "{0}:{1}".format(region, get_random_hex(length=19))
return "{0}:{1}".format(region, uuid4())

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import json
from werkzeug.exceptions import BadRequest
from moto.core.exceptions import JsonRESTError
class ResourceNotFoundError(BadRequest):
@ -42,3 +43,19 @@ class NotAuthorizedError(BadRequest):
self.description = json.dumps(
{"message": message, "__type": "NotAuthorizedException"}
)
class UserNotConfirmedException(BadRequest):
def __init__(self, message):
super(UserNotConfirmedException, self).__init__()
self.description = json.dumps(
{"message": message, "__type": "UserNotConfirmedException"}
)
class InvalidParameterException(JsonRESTError):
def __init__(self, msg=None):
self.code = 400
super(InvalidParameterException, self).__init__(
"InvalidParameterException", msg or "A parameter is specified incorrectly."
)

View File

@ -14,17 +14,22 @@ from jose import jws
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from .exceptions import (
GroupExistsException,
NotAuthorizedError,
ResourceNotFoundError,
UserNotFoundError,
UsernameExistsException,
UserNotConfirmedException,
InvalidParameterException,
)
from .utils import create_id, check_secret_hash
UserStatus = {
"FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD",
"CONFIRMED": "CONFIRMED",
"UNCONFIRMED": "UNCONFIRMED",
}
@ -69,6 +74,9 @@ class CognitoIdpUserPool(BaseModel):
def __init__(self, region, name, extended_config):
self.region = region
self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
self.arn = "arn:aws:cognito-idp:{}:{}:userpool/{}".format(
self.region, DEFAULT_ACCOUNT_ID, self.id
)
self.name = name
self.status = None
self.extended_config = extended_config or {}
@ -79,6 +87,7 @@ class CognitoIdpUserPool(BaseModel):
self.identity_providers = OrderedDict()
self.groups = OrderedDict()
self.users = OrderedDict()
self.resource_servers = OrderedDict()
self.refresh_tokens = {}
self.access_tokens = {}
self.id_tokens = {}
@ -91,6 +100,7 @@ class CognitoIdpUserPool(BaseModel):
def _base_json(self):
return {
"Id": self.id,
"Arn": self.arn,
"Name": self.name,
"Status": self.status,
"CreationDate": time.mktime(self.creation_date.timetuple()),
@ -123,8 +133,12 @@ class CognitoIdpUserPool(BaseModel):
"exp": now + expires_in,
}
payload.update(extra_data)
headers = {"kid": "dummy"} # KID as present in jwks-public.json
return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in
return (
jws.sign(payload, self.json_web_key, headers, algorithm="RS256"),
expires_in,
)
def create_id_token(self, client_id, username):
extra_data = self.get_user_extra_data_by_client_id(client_id, username)
@ -201,10 +215,11 @@ class CognitoIdpUserPoolDomain(BaseModel):
class CognitoIdpUserPoolClient(BaseModel):
def __init__(self, user_pool_id, extended_config):
def __init__(self, user_pool_id, generate_secret, extended_config):
self.user_pool_id = user_pool_id
self.id = str(uuid.uuid4())
self.id = create_id()
self.secret = str(uuid.uuid4())
self.generate_secret = generate_secret or False
self.extended_config = extended_config or {}
def _base_json(self):
@ -216,6 +231,8 @@ class CognitoIdpUserPoolClient(BaseModel):
def to_json(self, extended=False):
user_pool_client_json = self._base_json()
if self.generate_secret:
user_pool_client_json.update({"ClientSecret": self.secret})
if extended:
user_pool_client_json.update(self.extended_config)
@ -285,6 +302,9 @@ class CognitoIdpUser(BaseModel):
self.attributes = attributes
self.create_date = datetime.datetime.utcnow()
self.last_modified_date = datetime.datetime.utcnow()
self.sms_mfa_enabled = False
self.software_token_mfa_enabled = False
self.token_verified = False
# Groups this user is a member of.
# Note that these links are bidirectional.
@ -301,6 +321,11 @@ class CognitoIdpUser(BaseModel):
# list_users brings back "Attributes" while admin_get_user brings back "UserAttributes".
def to_json(self, extended=False, attributes_key="Attributes"):
user_mfa_setting_list = []
if self.software_token_mfa_enabled:
user_mfa_setting_list.append("SOFTWARE_TOKEN_MFA")
elif self.sms_mfa_enabled:
user_mfa_setting_list.append("SMS_MFA")
user_json = self._base_json()
if extended:
user_json.update(
@ -308,6 +333,7 @@ class CognitoIdpUser(BaseModel):
"Enabled": self.enabled,
attributes_key: self.attributes,
"MFAOptions": [],
"UserMFASettingList": user_mfa_setting_list,
}
)
@ -325,6 +351,26 @@ class CognitoIdpUser(BaseModel):
self.attributes = expand_attrs(flat_attributes)
class CognitoResourceServer(BaseModel):
def __init__(self, user_pool_id, identifier, name, scopes):
self.user_pool_id = user_pool_id
self.identifier = identifier
self.name = name
self.scopes = scopes
def to_json(self):
res = {
"UserPoolId": self.user_pool_id,
"Identifier": self.identifier,
"Name": self.name,
}
if len(self.scopes) != 0:
res.update({"Scopes": self.scopes})
return res
class CognitoIdpBackend(BaseBackend):
def __init__(self, region):
super(CognitoIdpBackend, self).__init__()
@ -393,12 +439,14 @@ class CognitoIdpBackend(BaseBackend):
return user_pool_domain
# User pool client
def create_user_pool_client(self, user_pool_id, extended_config):
def create_user_pool_client(self, user_pool_id, generate_secret, extended_config):
user_pool = self.user_pools.get(user_pool_id)
if not user_pool:
raise ResourceNotFoundError(user_pool_id)
user_pool_client = CognitoIdpUserPoolClient(user_pool_id, extended_config)
user_pool_client = CognitoIdpUserPoolClient(
user_pool_id, generate_secret, extended_config
)
user_pool.clients[user_pool_client.id] = user_pool_client
return user_pool_client
@ -693,6 +741,9 @@ class CognitoIdpBackend(BaseBackend):
def respond_to_auth_challenge(
self, session, client_id, challenge_name, challenge_responses
):
if challenge_name == "PASSWORD_VERIFIER":
session = challenge_responses.get("PASSWORD_CLAIM_SECRET_BLOCK")
user_pool = self.sessions.get(session)
if not user_pool:
raise ResourceNotFoundError(session)
@ -713,6 +764,62 @@ class CognitoIdpBackend(BaseBackend):
del self.sessions[session]
return self._log_user_in(user_pool, client, username)
elif challenge_name == "PASSWORD_VERIFIER":
username = challenge_responses.get("USERNAME")
user = user_pool.users.get(username)
if not user:
raise UserNotFoundError(username)
password_claim_signature = challenge_responses.get(
"PASSWORD_CLAIM_SIGNATURE"
)
if not password_claim_signature:
raise ResourceNotFoundError(password_claim_signature)
password_claim_secret_block = challenge_responses.get(
"PASSWORD_CLAIM_SECRET_BLOCK"
)
if not password_claim_secret_block:
raise ResourceNotFoundError(password_claim_secret_block)
timestamp = challenge_responses.get("TIMESTAMP")
if not timestamp:
raise ResourceNotFoundError(timestamp)
if user.software_token_mfa_enabled:
return {
"ChallengeName": "SOFTWARE_TOKEN_MFA",
"Session": session,
"ChallengeParameters": {},
}
if user.sms_mfa_enabled:
return {
"ChallengeName": "SMS_MFA",
"Session": session,
"ChallengeParameters": {},
}
del self.sessions[session]
return self._log_user_in(user_pool, client, username)
elif challenge_name == "SOFTWARE_TOKEN_MFA":
username = challenge_responses.get("USERNAME")
user = user_pool.users.get(username)
if not user:
raise UserNotFoundError(username)
software_token_mfa_code = challenge_responses.get("SOFTWARE_TOKEN_MFA_CODE")
if not software_token_mfa_code:
raise ResourceNotFoundError(software_token_mfa_code)
if client.generate_secret:
secret_hash = challenge_responses.get("SECRET_HASH")
if not check_secret_hash(
client.secret, client.id, username, secret_hash
):
raise NotAuthorizedError(secret_hash)
del self.sessions[session]
return self._log_user_in(user_pool, client, username)
else:
return {}
@ -754,6 +861,187 @@ class CognitoIdpBackend(BaseBackend):
user = user_pool.users[username]
user.update_attributes(attributes)
def create_resource_server(self, user_pool_id, identifier, name, scopes):
user_pool = self.user_pools.get(user_pool_id)
if not user_pool:
raise ResourceNotFoundError(user_pool_id)
if identifier in user_pool.resource_servers:
raise InvalidParameterException(
"%s already exists in user pool %s." % (identifier, user_pool_id)
)
resource_server = CognitoResourceServer(user_pool_id, identifier, name, scopes)
user_pool.resource_servers[identifier] = resource_server
return resource_server
def sign_up(self, client_id, username, password, attributes):
user_pool = None
for p in self.user_pools.values():
if client_id in p.clients:
user_pool = p
if user_pool is None:
raise ResourceNotFoundError(client_id)
user = CognitoIdpUser(
user_pool_id=user_pool.id,
username=username,
password=password,
attributes=attributes,
status=UserStatus["UNCONFIRMED"],
)
user_pool.users[user.username] = user
return user
def confirm_sign_up(self, client_id, username, confirmation_code):
user_pool = None
for p in self.user_pools.values():
if client_id in p.clients:
user_pool = p
if user_pool is None:
raise ResourceNotFoundError(client_id)
if username not in user_pool.users:
raise UserNotFoundError(username)
user = user_pool.users[username]
user.status = UserStatus["CONFIRMED"]
return ""
def initiate_auth(self, client_id, auth_flow, auth_parameters):
user_pool = None
for p in self.user_pools.values():
if client_id in p.clients:
user_pool = p
if user_pool is None:
raise ResourceNotFoundError(client_id)
client = p.clients.get(client_id)
if auth_flow == "USER_SRP_AUTH":
username = auth_parameters.get("USERNAME")
srp_a = auth_parameters.get("SRP_A")
if not srp_a:
raise ResourceNotFoundError(srp_a)
if client.generate_secret:
secret_hash = auth_parameters.get("SECRET_HASH")
if not check_secret_hash(
client.secret, client.id, username, secret_hash
):
raise NotAuthorizedError(secret_hash)
user = user_pool.users.get(username)
if not user:
raise UserNotFoundError(username)
if user.status == UserStatus["UNCONFIRMED"]:
raise UserNotConfirmedException("User is not confirmed.")
session = str(uuid.uuid4())
self.sessions[session] = user_pool
return {
"ChallengeName": "PASSWORD_VERIFIER",
"Session": session,
"ChallengeParameters": {
"SALT": str(uuid.uuid4()),
"SRP_B": str(uuid.uuid4()),
"USERNAME": user.id,
"USER_ID_FOR_SRP": user.id,
"SECRET_BLOCK": session,
},
}
elif auth_flow == "REFRESH_TOKEN":
refresh_token = auth_parameters.get("REFRESH_TOKEN")
if not refresh_token:
raise ResourceNotFoundError(refresh_token)
client_id, username = user_pool.refresh_tokens[refresh_token]
if not username:
raise ResourceNotFoundError(username)
if client.generate_secret:
secret_hash = auth_parameters.get("SECRET_HASH")
if not check_secret_hash(
client.secret, client.id, username, secret_hash
):
raise NotAuthorizedError(secret_hash)
(
id_token,
access_token,
expires_in,
) = user_pool.create_tokens_from_refresh_token(refresh_token)
return {
"AuthenticationResult": {
"IdToken": id_token,
"AccessToken": access_token,
"ExpiresIn": expires_in,
}
}
else:
return None
def associate_software_token(self, access_token):
for user_pool in self.user_pools.values():
if access_token in user_pool.access_tokens:
_, username = user_pool.access_tokens[access_token]
user = user_pool.users.get(username)
if not user:
raise UserNotFoundError(username)
return {"SecretCode": str(uuid.uuid4())}
else:
raise NotAuthorizedError(access_token)
def verify_software_token(self, access_token, user_code):
for user_pool in self.user_pools.values():
if access_token in user_pool.access_tokens:
_, username = user_pool.access_tokens[access_token]
user = user_pool.users.get(username)
if not user:
raise UserNotFoundError(username)
user.token_verified = True
return {"Status": "SUCCESS"}
else:
raise NotAuthorizedError(access_token)
def set_user_mfa_preference(
self, access_token, software_token_mfa_settings, sms_mfa_settings
):
for user_pool in self.user_pools.values():
if access_token in user_pool.access_tokens:
_, username = user_pool.access_tokens[access_token]
user = user_pool.users.get(username)
if not user:
raise UserNotFoundError(username)
if software_token_mfa_settings["Enabled"]:
if user.token_verified:
user.software_token_mfa_enabled = True
else:
raise InvalidParameterException(
"User has not verified software token mfa"
)
elif sms_mfa_settings["Enabled"]:
user.sms_mfa_enabled = True
return None
else:
raise NotAuthorizedError(access_token)
def admin_set_user_password(self, user_pool_id, username, password, permanent):
user = self.admin_get_user(user_pool_id, username)
user.password = password
if permanent:
user.status = UserStatus["CONFIRMED"]
else:
user.status = UserStatus["FORCE_CHANGE_PASSWORD"]
cognitoidp_backends = {}
for region in Session().get_available_regions("cognito-idp"):
@ -778,5 +1066,7 @@ def find_region_by_value(key, value):
if key == "access_token" and value in user_pool.access_tokens:
return region
return cognitoidp_backends.keys()[0]
# If we can't find the `client_id` or `access_token`, we just pass
# back a default backend region, which will raise the appropriate
# error message (e.g. NotAuthorized or NotFound).
return list(cognitoidp_backends)[0]

View File

@ -4,7 +4,7 @@ import json
import os
from moto.core.responses import BaseResponse
from .models import cognitoidp_backends, find_region_by_value
from .models import cognitoidp_backends, find_region_by_value, UserStatus
class CognitoIdpResponse(BaseResponse):
@ -84,8 +84,9 @@ class CognitoIdpResponse(BaseResponse):
# User pool client
def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId")
generate_secret = self.parameters.pop("GenerateSecret", False)
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(
user_pool_id, self.parameters
user_pool_id, generate_secret, self.parameters
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
@ -286,7 +287,7 @@ class CognitoIdpResponse(BaseResponse):
user_pool_id, limit=limit, pagination_token=token
)
if filt:
name, value = filt.replace('"', "").split("=")
name, value = filt.replace('"', "").replace(" ", "").split("=")
users = [
user
for user in users
@ -378,6 +379,86 @@ class CognitoIdpResponse(BaseResponse):
)
return ""
# Resource Server
def create_resource_server(self):
user_pool_id = self._get_param("UserPoolId")
identifier = self._get_param("Identifier")
name = self._get_param("Name")
scopes = self._get_param("Scopes")
resource_server = cognitoidp_backends[self.region].create_resource_server(
user_pool_id, identifier, name, scopes
)
return json.dumps({"ResourceServer": resource_server.to_json()})
def sign_up(self):
client_id = self._get_param("ClientId")
username = self._get_param("Username")
password = self._get_param("Password")
user = cognitoidp_backends[self.region].sign_up(
client_id=client_id,
username=username,
password=password,
attributes=self._get_param("UserAttributes", []),
)
return json.dumps(
{
"UserConfirmed": user.status == UserStatus["CONFIRMED"],
"UserSub": user.id,
}
)
def confirm_sign_up(self):
client_id = self._get_param("ClientId")
username = self._get_param("Username")
confirmation_code = self._get_param("ConfirmationCode")
cognitoidp_backends[self.region].confirm_sign_up(
client_id=client_id, username=username, confirmation_code=confirmation_code,
)
return ""
def initiate_auth(self):
client_id = self._get_param("ClientId")
auth_flow = self._get_param("AuthFlow")
auth_parameters = self._get_param("AuthParameters")
auth_result = cognitoidp_backends[self.region].initiate_auth(
client_id, auth_flow, auth_parameters
)
return json.dumps(auth_result)
def associate_software_token(self):
access_token = self._get_param("AccessToken")
result = cognitoidp_backends[self.region].associate_software_token(access_token)
return json.dumps(result)
def verify_software_token(self):
access_token = self._get_param("AccessToken")
user_code = self._get_param("UserCode")
result = cognitoidp_backends[self.region].verify_software_token(
access_token, user_code
)
return json.dumps(result)
def set_user_mfa_preference(self):
access_token = self._get_param("AccessToken")
software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings")
sms_mfa_settings = self._get_param("SMSMfaSettings")
cognitoidp_backends[self.region].set_user_mfa_preference(
access_token, software_token_mfa_settings, sms_mfa_settings
)
return ""
def admin_set_user_password(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
password = self._get_param("Password")
permanent = self._get_param("Permanent")
cognitoidp_backends[self.region].admin_set_user_password(
user_pool_id, username, password, permanent
)
return ""
class CognitoIdpJsonWebKeyResponse(BaseResponse):
def __init__(self):

View File

@ -5,5 +5,5 @@ url_bases = ["https?://cognito-idp.(.+).amazonaws.com"]
url_paths = {
"{0}/$": CognitoIdpResponse.dispatch,
"{0}/<user_pool_id>/.well-known/jwks.json$": CognitoIdpJsonWebKeyResponse().serve_json_web_key,
"{0}/(?P<user_pool_id>[^/]+)/.well-known/jwks.json$": CognitoIdpJsonWebKeyResponse().serve_json_web_key,
}

21
moto/cognitoidp/utils.py Normal file
View File

@ -0,0 +1,21 @@
from __future__ import unicode_literals
import six
import random
import string
import hashlib
import hmac
import base64
def create_id():
size = 26
chars = list(range(10)) + list(string.ascii_lowercase)
return "".join(six.text_type(random.choice(chars)) for x in range(size))
def check_secret_hash(app_client_secret, app_client_id, username, secret_hash):
key = bytes(str(app_client_secret).encode("latin-1"))
msg = bytes(str(username + app_client_id).encode("latin-1"))
new_digest = hmac.new(key, msg, hashlib.sha256).digest()
SECRET_HASH = base64.b64encode(new_digest).decode()
return SECRET_HASH == secret_hash

View File

@ -366,3 +366,29 @@ class TooManyResourceKeys(JsonRESTError):
message = str(message)
super(TooManyResourceKeys, self).__init__("ValidationException", message)
class InvalidResultTokenException(JsonRESTError):
code = 400
def __init__(self):
message = "The resultToken provided is invalid"
super(InvalidResultTokenException, self).__init__(
"InvalidResultTokenException", message
)
class ValidationException(JsonRESTError):
code = 400
def __init__(self, message):
super(ValidationException, self).__init__("ValidationException", message)
class NoSuchOrganizationConformancePackException(JsonRESTError):
code = 400
def __init__(self, message):
super(NoSuchOrganizationConformancePackException, self).__init__(
"NoSuchOrganizationConformancePackException", message
)

View File

@ -40,13 +40,17 @@ from moto.config.exceptions import (
TooManyResourceIds,
ResourceNotDiscoveredException,
TooManyResourceKeys,
InvalidResultTokenException,
ValidationException,
NoSuchOrganizationConformancePackException,
)
from moto.core import BaseBackend, BaseModel
from moto.s3.config import s3_config_query
from moto.s3.config import s3_account_public_access_block_query, s3_config_query
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from moto.iam.config import role_config_query, policy_config_query
POP_STRINGS = [
"capitalizeStart",
"CapitalizeStart",
@ -58,7 +62,12 @@ POP_STRINGS = [
DEFAULT_PAGE_SIZE = 100
# Map the Config resource type to a backend:
RESOURCE_MAP = {"AWS::S3::Bucket": s3_config_query}
RESOURCE_MAP = {
"AWS::S3::Bucket": s3_config_query,
"AWS::S3::AccountPublicAccessBlock": s3_account_public_access_block_query,
"AWS::IAM::Role": role_config_query,
"AWS::IAM::Policy": policy_config_query,
}
def datetime2int(date):
@ -155,7 +164,8 @@ class ConfigEmptyDictable(BaseModel):
def to_dict(self):
data = {}
for item, value in self.__dict__.items():
if value is not None:
# ignore private attributes
if not item.startswith("_") and value is not None:
if isinstance(value, ConfigEmptyDictable):
data[
snake_to_camels(
@ -363,12 +373,56 @@ class ConfigAggregationAuthorization(ConfigEmptyDictable):
self.tags = tags or {}
class OrganizationConformancePack(ConfigEmptyDictable):
def __init__(
self,
region,
name,
delivery_s3_bucket,
delivery_s3_key_prefix=None,
input_parameters=None,
excluded_accounts=None,
):
super(OrganizationConformancePack, self).__init__(
capitalize_start=True, capitalize_arn=False
)
self._status = "CREATE_SUCCESSFUL"
self._unique_pack_name = "{0}-{1}".format(name, random_string())
self.conformance_pack_input_parameters = input_parameters or []
self.delivery_s3_bucket = delivery_s3_bucket
self.delivery_s3_key_prefix = delivery_s3_key_prefix
self.excluded_accounts = excluded_accounts or []
self.last_update_time = datetime2int(datetime.utcnow())
self.organization_conformance_pack_arn = "arn:aws:config:{0}:{1}:organization-conformance-pack/{2}".format(
region, DEFAULT_ACCOUNT_ID, self._unique_pack_name
)
self.organization_conformance_pack_name = name
def update(
self,
delivery_s3_bucket,
delivery_s3_key_prefix,
input_parameters,
excluded_accounts,
):
self._status = "UPDATE_SUCCESSFUL"
self.conformance_pack_input_parameters = input_parameters
self.delivery_s3_bucket = delivery_s3_bucket
self.delivery_s3_key_prefix = delivery_s3_key_prefix
self.excluded_accounts = excluded_accounts
self.last_update_time = datetime2int(datetime.utcnow())
class ConfigBackend(BaseBackend):
def __init__(self):
self.recorders = {}
self.delivery_channels = {}
self.config_aggregators = {}
self.aggregation_authorizations = {}
self.organization_conformance_packs = {}
@staticmethod
def _validate_resource_types(resource_list):
@ -867,16 +921,17 @@ class ConfigBackend(BaseBackend):
backend_region=backend_query_region,
)
result = {
"resourceIdentifiers": [
{
"resourceType": identifier["type"],
"resourceId": identifier["id"],
"resourceName": identifier["name"],
}
for identifier in identifiers
]
}
resource_identifiers = []
for identifier in identifiers:
item = {"resourceType": identifier["type"], "resourceId": identifier["id"]}
# Some resource types lack names:
if identifier.get("name"):
item["resourceName"] = identifier["name"]
resource_identifiers.append(item)
result = {"resourceIdentifiers": resource_identifiers}
if new_token:
result["nextToken"] = new_token
@ -925,20 +980,23 @@ class ConfigBackend(BaseBackend):
limit,
next_token,
resource_region=resource_region,
aggregator=self.config_aggregators.get(aggregator_name).__dict__,
)
result = {
"ResourceIdentifiers": [
{
"SourceAccountId": DEFAULT_ACCOUNT_ID,
"SourceRegion": identifier["region"],
"ResourceType": identifier["type"],
"ResourceId": identifier["id"],
"ResourceName": identifier["name"],
}
for identifier in identifiers
]
}
resource_identifiers = []
for identifier in identifiers:
item = {
"SourceAccountId": DEFAULT_ACCOUNT_ID,
"SourceRegion": identifier["region"],
"ResourceType": identifier["type"],
"ResourceId": identifier["id"],
}
if identifier.get("name"):
item["ResourceName"] = identifier["name"]
resource_identifiers.append(item)
result = {"ResourceIdentifiers": resource_identifiers}
if new_token:
result["NextToken"] = new_token
@ -948,9 +1006,9 @@ class ConfigBackend(BaseBackend):
def get_resource_config_history(self, resource_type, id, backend_region):
"""Returns the configuration of an item in the AWS Config format of the resource for the current regional backend.
NOTE: This is --NOT-- returning history as it is not supported in moto at this time. (PR's welcome!)
As such, the later_time, earlier_time, limit, and next_token are ignored as this will only
return 1 item. (If no items, it raises an exception)
NOTE: This is --NOT-- returning history as it is not supported in moto at this time. (PR's welcome!)
As such, the later_time, earlier_time, limit, and next_token are ignored as this will only
return 1 item. (If no items, it raises an exception)
"""
# If the type isn't implemented then we won't find the item:
if resource_type not in RESOURCE_MAP:
@ -1032,10 +1090,10 @@ class ConfigBackend(BaseBackend):
):
"""Returns the configuration of an item in the AWS Config format of the resource for the current regional backend.
As far a moto goes -- the only real difference between this function and the `batch_get_resource_config` function is that
this will require a Config Aggregator be set up a priori and can search based on resource regions.
As far a moto goes -- the only real difference between this function and the `batch_get_resource_config` function is that
this will require a Config Aggregator be set up a priori and can search based on resource regions.
Note: moto will IGNORE the resource account ID in the search query.
Note: moto will IGNORE the resource account ID in the search query.
"""
if not self.config_aggregators.get(aggregator_name):
raise NoSuchConfigurationAggregatorException()
@ -1082,6 +1140,154 @@ class ConfigBackend(BaseBackend):
"UnprocessedResourceIdentifiers": not_found,
}
def put_evaluations(self, evaluations=None, result_token=None, test_mode=False):
if not evaluations:
raise InvalidParameterValueException(
"The Evaluations object in your request cannot be null."
"Add the required parameters and try again."
)
if not result_token:
raise InvalidResultTokenException()
# Moto only supports PutEvaluations with test mode currently (missing rule and token support)
if not test_mode:
raise NotImplementedError(
"PutEvaluations without TestMode is not yet implemented"
)
return {
"FailedEvaluations": [],
} # At this time, moto is not adding failed evaluations.
def put_organization_conformance_pack(
self,
region,
name,
template_s3_uri,
template_body,
delivery_s3_bucket,
delivery_s3_key_prefix,
input_parameters,
excluded_accounts,
):
# a real validation of the content of the template is missing at the moment
if not template_s3_uri and not template_body:
raise ValidationException("Template body is invalid")
if not re.match(r"s3://.*", template_s3_uri):
raise ValidationException(
"1 validation error detected: "
"Value '{}' at 'templateS3Uri' failed to satisfy constraint: "
"Member must satisfy regular expression pattern: "
"s3://.*".format(template_s3_uri)
)
pack = self.organization_conformance_packs.get(name)
if pack:
pack.update(
delivery_s3_bucket=delivery_s3_bucket,
delivery_s3_key_prefix=delivery_s3_key_prefix,
input_parameters=input_parameters,
excluded_accounts=excluded_accounts,
)
else:
pack = OrganizationConformancePack(
region=region,
name=name,
delivery_s3_bucket=delivery_s3_bucket,
delivery_s3_key_prefix=delivery_s3_key_prefix,
input_parameters=input_parameters,
excluded_accounts=excluded_accounts,
)
self.organization_conformance_packs[name] = pack
return {
"OrganizationConformancePackArn": pack.organization_conformance_pack_arn
}
def describe_organization_conformance_packs(self, names):
packs = []
for name in names:
pack = self.organization_conformance_packs.get(name)
if not pack:
raise NoSuchOrganizationConformancePackException(
"One or more organization conformance packs with specified names are not present. "
"Ensure your names are correct and try your request again later."
)
packs.append(pack.to_dict())
return {"OrganizationConformancePacks": packs}
def describe_organization_conformance_pack_statuses(self, names):
packs = []
statuses = []
if names:
for name in names:
pack = self.organization_conformance_packs.get(name)
if not pack:
raise NoSuchOrganizationConformancePackException(
"One or more organization conformance packs with specified names are not present. "
"Ensure your names are correct and try your request again later."
)
packs.append(pack)
else:
packs = list(self.organization_conformance_packs.values())
for pack in packs:
statuses.append(
{
"OrganizationConformancePackName": pack.organization_conformance_pack_name,
"Status": pack._status,
"LastUpdateTime": pack.last_update_time,
}
)
return {"OrganizationConformancePackStatuses": statuses}
def get_organization_conformance_pack_detailed_status(self, name):
pack = self.organization_conformance_packs.get(name)
if not pack:
raise NoSuchOrganizationConformancePackException(
"One or more organization conformance packs with specified names are not present. "
"Ensure your names are correct and try your request again later."
)
# actually here would be a list of all accounts in the organization
statuses = [
{
"AccountId": DEFAULT_ACCOUNT_ID,
"ConformancePackName": "OrgConformsPack-{0}".format(
pack._unique_pack_name
),
"Status": pack._status,
"LastUpdateTime": datetime2int(datetime.utcnow()),
}
]
return {"OrganizationConformancePackDetailedStatuses": statuses}
def delete_organization_conformance_pack(self, name):
pack = self.organization_conformance_packs.get(name)
if not pack:
raise NoSuchOrganizationConformancePackException(
"Could not find an OrganizationConformancePack for given request with resourceName {}".format(
name
)
)
self.organization_conformance_packs.pop(name)
config_backends = {}
for region in Session().get_available_regions("config"):

View File

@ -151,3 +151,54 @@ class ConfigResponse(BaseResponse):
self._get_param("ResourceIdentifiers"),
)
return json.dumps(schema)
def put_evaluations(self):
evaluations = self.config_backend.put_evaluations(
self._get_param("Evaluations"),
self._get_param("ResultToken"),
self._get_param("TestMode"),
)
return json.dumps(evaluations)
def put_organization_conformance_pack(self):
conformance_pack = self.config_backend.put_organization_conformance_pack(
region=self.region,
name=self._get_param("OrganizationConformancePackName"),
template_s3_uri=self._get_param("TemplateS3Uri"),
template_body=self._get_param("TemplateBody"),
delivery_s3_bucket=self._get_param("DeliveryS3Bucket"),
delivery_s3_key_prefix=self._get_param("DeliveryS3KeyPrefix"),
input_parameters=self._get_param("ConformancePackInputParameters"),
excluded_accounts=self._get_param("ExcludedAccounts"),
)
return json.dumps(conformance_pack)
def describe_organization_conformance_packs(self):
conformance_packs = self.config_backend.describe_organization_conformance_packs(
self._get_param("OrganizationConformancePackNames")
)
return json.dumps(conformance_packs)
def describe_organization_conformance_pack_statuses(self):
statuses = self.config_backend.describe_organization_conformance_pack_statuses(
self._get_param("OrganizationConformancePackNames")
)
return json.dumps(statuses)
def get_organization_conformance_pack_detailed_status(self):
# 'Filters' parameter is not implemented yet
statuses = self.config_backend.get_organization_conformance_pack_detailed_status(
self._get_param("OrganizationConformancePackName")
)
return json.dumps(statuses)
def delete_organization_conformance_pack(self):
self.config_backend.delete_organization_conformance_pack(
self._get_param("OrganizationConformancePackName")
)
return ""

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
from .models import BaseModel, BaseBackend, moto_api_backend, ACCOUNT_ID # noqa
from .models import CloudFormationModel # noqa
from .responses import ActionAuthenticatorMixin
moto_api_backends = {"global": moto_api_backend}

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment
import json
SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
@ -109,6 +110,22 @@ class AuthFailureError(RESTError):
)
class AWSError(Exception):
TYPE = None
STATUS = 400
def __init__(self, message, type=None, status=None):
self.message = message
self.type = type if type is not None else self.TYPE
self.status = status if status is not None else self.STATUS
def response(self):
return (
json.dumps({"__type": self.type, "message": self.message}),
dict(status=self.status),
)
class InvalidNextTokenException(JsonRESTError):
"""For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core."""

View File

@ -5,12 +5,19 @@ from __future__ import absolute_import
import functools
import inspect
import os
import pkg_resources
import re
import six
import types
from abc import abstractmethod
from io import BytesIO
from collections import defaultdict
from botocore.config import Config
from botocore.handlers import BUILTIN_HANDLERS
from botocore.awsrequest import AWSResponse
from distutils.version import LooseVersion
from six.moves.urllib.parse import urlparse
from werkzeug.wrappers import Request
import mock
from moto import settings
@ -22,22 +29,23 @@ from .utils import (
convert_flask_to_responses_response,
)
ACCOUNT_ID = os.environ.get("MOTO_ACCOUNT_ID", "123456789012")
RESPONSES_VERSION = pkg_resources.get_distribution("responses").version
class BaseMockAWS(object):
nested_count = 0
def __init__(self, backends):
from moto.instance_metadata import instance_metadata_backend
from moto.core import moto_api_backend
self.backends = backends
self.backends_for_urls = {}
from moto.backends import BACKENDS
default_backends = {
"instance_metadata": BACKENDS["instance_metadata"]["global"],
"moto_api": BACKENDS["moto_api"]["global"],
"instance_metadata": instance_metadata_backend,
"moto_api": moto_api_backend,
}
self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends)
@ -174,6 +182,28 @@ class CallbackResponse(responses.CallbackResponse):
"""
Need to override this so we can pass decode_content=False
"""
if not isinstance(request, Request):
url = urlparse(request.url)
if request.body is None:
body = None
elif isinstance(request.body, six.text_type):
body = six.BytesIO(six.b(request.body))
elif hasattr(request.body, "read"):
body = six.BytesIO(request.body.read())
else:
body = six.BytesIO(request.body)
req = Request.from_values(
path="?".join([url.path, url.query]),
input_stream=body,
content_length=request.headers.get("Content-Length"),
content_type=request.headers.get("Content-Type"),
method=request.method,
base_url="{scheme}://{netloc}".format(
scheme=url.scheme, netloc=url.netloc
),
headers=[(k, v) for k, v in six.iteritems(request.headers)],
)
request = req
headers = self.get_headers()
result = self.callback(request)
@ -217,12 +247,46 @@ botocore_mock = responses.RequestsMock(
assert_all_requests_are_fired=False,
target="botocore.vendored.requests.adapters.HTTPAdapter.send",
)
responses_mock = responses._default_mock
# Add passthrough to allow any other requests to work
# Since this uses .startswith, it applies to http and https requests.
responses_mock.add_passthru("http")
def _find_first_match_legacy(self, request):
for i, match in enumerate(self._matches):
if match.matches(request):
return match
return None
def _find_first_match(self, request):
match_failed_reasons = []
for i, match in enumerate(self._matches):
match_result, reason = match.matches(request)
if match_result:
return match, match_failed_reasons
else:
match_failed_reasons.append(reason)
return None, match_failed_reasons
# Modify behaviour of the matcher to only/always return the first match
# Default behaviour is to return subsequent matches for subsequent requests, which leads to https://github.com/spulec/moto/issues/2567
# - First request matches on the appropriate S3 URL
# - Same request, executed again, will be matched on the subsequent match, which happens to be the catch-all, not-yet-implemented, callback
# Fix: Always return the first match
if LooseVersion(RESPONSES_VERSION) < LooseVersion("0.12.1"):
responses_mock._find_match = types.MethodType(
_find_first_match_legacy, responses_mock
)
else:
responses_mock._find_match = types.MethodType(_find_first_match, responses_mock)
BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
@ -329,7 +393,7 @@ class BotocoreEventMockAWS(BaseMockAWS):
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile("https?://.+.amazonaws.com/.*"),
url=re.compile(r"https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
@ -338,7 +402,7 @@ class BotocoreEventMockAWS(BaseMockAWS):
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile("https?://.+.amazonaws.com/.*"),
url=re.compile(r"https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
@ -373,6 +437,13 @@ class ServerModeMockAWS(BaseMockAWS):
import mock
def fake_boto3_client(*args, **kwargs):
region = self._get_region(*args, **kwargs)
if region:
if "config" in kwargs:
kwargs["config"].__dict__["user_agent_extra"] += " region/" + region
else:
config = Config(user_agent_extra="region/" + region)
kwargs["config"] = config
if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_client(*args, **kwargs)
@ -420,6 +491,14 @@ class ServerModeMockAWS(BaseMockAWS):
if six.PY2:
self._httplib_patcher.start()
def _get_region(self, *args, **kwargs):
if "region_name" in kwargs:
return kwargs["region_name"]
if type(args) == tuple and len(args) == 2:
service, region = args
return region
return None
def disable_patching(self):
if self._client_patcher:
self._client_patcher.stop()
@ -475,6 +554,56 @@ class BaseModel(object):
return instance
# Parent class for every Model that can be instantiated by CloudFormation
# On subclasses, implement the two methods as @staticmethod to ensure correct behaviour of the CF parser
class CloudFormationModel(BaseModel):
@staticmethod
@abstractmethod
def cloudformation_name_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-name.html
# This must be implemented as a staticmethod with no parameters
# Return None for resources that do not have a name property
pass
@staticmethod
@abstractmethod
def cloudformation_type():
# This must be implemented as a staticmethod with no parameters
# See for example https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html
return "AWS::SERVICE::RESOURCE"
@abstractmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
# This must be implemented as a classmethod with parameters:
# cls, resource_name, cloudformation_json, region_name
# Extract the resource parameters from the cloudformation json
# and return an instance of the resource class
pass
@abstractmethod
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
# This must be implemented as a classmethod with parameters:
# cls, original_resource, new_resource_name, cloudformation_json, region_name
# Extract the resource parameters from the cloudformation json,
# delete the old resource and return the new one. Optionally inspect
# the change in parameters and no-op when nothing has changed.
pass
@abstractmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
# This must be implemented as a classmethod with parameters:
# cls, resource_name, cloudformation_json, region_name
# Extract the resource parameters from the cloudformation json
# and delete the resource. Do not include a return statement.
pass
class BaseBackend(object):
def _reset_model_refs(self):
# Remove all references to the models stored
@ -582,6 +711,7 @@ class ConfigQueryModel(object):
next_token,
backend_region=None,
resource_region=None,
aggregator=None,
):
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region.
@ -606,12 +736,17 @@ class ConfigQueryModel(object):
As such, the proper way to implement is to first obtain a full list of results from all the region backends, and then filter
from there. It may be valuable to make this a concatenation of the region and resource name.
:param resource_region:
:param resource_ids:
:param resource_name:
:param limit:
:param next_token:
:param resource_ids: A list of resource IDs
:param resource_name: The individual name of a resource
:param limit: How many per page
:param next_token: The item that will page on
:param backend_region: The region for the backend to pull results from. Set to `None` if this is an aggregated query.
:param resource_region: The region for where the resources reside to pull results from. Set to `None` if this is a
non-aggregated query.
:param aggregator: If the query is an aggregated query, *AND* the resource has "non-standard" aggregation logic (mainly, IAM),
you'll need to pass aggregator used. In most cases, this should be omitted/set to `None`. See the
conditional logic under `if aggregator` in the moto/iam/config.py for the IAM example.
:return: This should return a list of Dicts that have the following fields:
[
{
@ -680,12 +815,12 @@ class deprecated_base_decorator(base_decorator):
class MotoAPIBackend(BaseBackend):
def reset(self):
from moto.backends import BACKENDS
import moto.backends as backends
for name, backends in BACKENDS.items():
for name, backends_ in backends.named_backends():
if name == "moto_api":
continue
for region_name, backend in backends.items():
for region_name, backend in backends_.items():
backend.reset()
self.__init__()

View File

@ -11,16 +11,14 @@ import requests
import pytz
from moto.core.access_control import IAMRequest, S3IAMRequest
from moto.core.exceptions import DryRunClientError
from jinja2 import Environment, DictLoader, TemplateNotFound
import six
from six.moves.urllib.parse import parse_qs, urlparse
from six.moves.urllib.parse import parse_qs, parse_qsl, urlparse
import xmltodict
from pkg_resources import resource_filename
from werkzeug.exceptions import HTTPException
import boto3
@ -32,7 +30,7 @@ log = logging.getLogger(__name__)
def _decode_dict(d):
decoded = {}
decoded = OrderedDict()
for key, value in d.items():
if isinstance(key, six.binary_type):
newkey = key.decode("utf-8")
@ -64,9 +62,9 @@ def _decode_dict(d):
class DynamicDictLoader(DictLoader):
"""
Note: There's a bug in jinja2 pre-2.7.3 DictLoader where caching does not work.
Including the fixed (current) method version here to ensure performance benefit
even for those using older jinja versions.
Note: There's a bug in jinja2 pre-2.7.3 DictLoader where caching does not work.
Including the fixed (current) method version here to ensure performance benefit
even for those using older jinja versions.
"""
def get_source(self, environment, template):
@ -135,9 +133,13 @@ class ActionAuthenticatorMixin(object):
ActionAuthenticatorMixin.request_count += 1
def _authenticate_and_authorize_normal_action(self):
from moto.iam.access_control import IAMRequest
self._authenticate_and_authorize_action(IAMRequest)
def _authenticate_and_authorize_s3_action(self):
from moto.iam.access_control import S3IAMRequest
self._authenticate_and_authorize_action(S3IAMRequest)
@staticmethod
@ -186,6 +188,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = "us-east-1"
# to extract region, use [^.]
region_regex = re.compile(r"\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com")
region_from_useragent_regex = re.compile(
r"region/(?P<region>[a-z]{2}-[a-z]+-\d{1})"
)
param_list_regex = re.compile(r"(.*)\.(\d+)\.")
access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
@ -197,7 +202,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return cls()._dispatch(*args, **kwargs)
def setup_class(self, request, full_url, headers):
querystring = {}
querystring = OrderedDict()
if hasattr(request, "body"):
# Boto
self.body = request.body
@ -209,7 +214,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# definition for back-compatibility
self.body = request.data
querystring = {}
querystring = OrderedDict()
for key, value in request.form.items():
querystring[key] = [value]
@ -238,7 +243,14 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
querystring[key] = [value]
elif self.body:
try:
querystring.update(parse_qs(raw_body, keep_blank_values=True))
querystring.update(
OrderedDict(
(key, [value])
for key, value in parse_qsl(
raw_body, keep_blank_values=True
)
)
)
except UnicodeEncodeError:
pass # ignore encoding errors, as the body may not contain a legitimate querystring
if not querystring:
@ -263,9 +275,14 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.response_headers = {"server": "amazon.com"}
def get_region_from_url(self, request, full_url):
match = self.region_regex.search(full_url)
if match:
region = match.group(1)
url_match = self.region_regex.search(full_url)
user_agent_match = self.region_from_useragent_regex.search(
request.headers.get("User-Agent", "")
)
if url_match:
region = url_match.group(1)
elif user_agent_match:
region = user_agent_match.group(1)
elif (
"Authorization" in request.headers
and "AWS4" in request.headers["Authorization"]
@ -521,8 +538,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
returns
{
"SlaveInstanceType": "m1.small",
"InstanceCount": "1",
"slave_instance_type": "m1.small",
"instance_count": "1",
}
"""
params = {}
@ -766,6 +783,9 @@ class AWSServiceSpec(object):
"""
def __init__(self, path):
# Importing pkg_resources takes ~60ms; keep it local
from pkg_resources import resource_filename # noqa
self.path = resource_filename("botocore", path)
with io.open(self.path, "r", encoding="utf-8") as f:
spec = json.load(f)

View File

@ -16,7 +16,7 @@ REQUEST_ID_LONG = string.digits + string.ascii_uppercase
def camelcase_to_underscores(argument):
""" Converts a camelcase param like theNewAttribute to the equivalent
"""Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute"""
result = ""
prev_char_title = True
@ -42,9 +42,9 @@ def camelcase_to_underscores(argument):
def underscores_to_camelcase(argument):
""" Converts a camelcase param like the_new_attribute to the equivalent
"""Converts a camelcase param like the_new_attribute to the equivalent
camelcase version like theNewAttribute. Note that the first letter is
NOT capitalized by this function """
NOT capitalized by this function"""
result = ""
previous_was_underscore = False
for char in argument:
@ -57,6 +57,11 @@ def underscores_to_camelcase(argument):
return result
def pascal_to_camelcase(argument):
"""Converts a PascalCase param to the camelCase equivalent"""
return argument[0].lower() + argument[1:]
def method_names_from_class(clazz):
# On Python 2, methods are different from functions, and the `inspect`
# predicates distinguish between them. On Python 3, methods are just
@ -95,7 +100,7 @@ def convert_regex_to_flask_path(url_path):
match_name, match_pattern = reg.groups()
return '<regex("{0}"):{1}>'.format(match_pattern, match_name)
url_path = re.sub("\(\?P<(.*?)>(.*?)\)", caller, url_path)
url_path = re.sub(r"\(\?P<(.*?)>(.*?)\)", caller, url_path)
if url_path.endswith("/?"):
# Flask does own handling of trailing slashes
@ -187,7 +192,13 @@ def iso_8601_datetime_with_milliseconds(datetime):
def iso_8601_datetime_without_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
return None if datetime is None else datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
def iso_8601_datetime_without_milliseconds_s3(datetime):
return (
None if datetime is None else datetime.strftime("%Y-%m-%dT%H:%M:%S.000") + "Z"
)
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
@ -328,3 +339,63 @@ def py2_strip_unicode_keys(blob):
blob = new_set
return blob
def tags_from_query_string(
querystring_dict, prefix="Tag", key_suffix="Key", value_suffix="Value"
):
response_values = {}
for key, value in querystring_dict.items():
if key.startswith(prefix) and key.endswith(key_suffix):
tag_index = key.replace(prefix + ".", "").replace("." + key_suffix, "")
tag_key = querystring_dict.get(
"{prefix}.{index}.{key_suffix}".format(
prefix=prefix, index=tag_index, key_suffix=key_suffix,
)
)[0]
tag_value_key = "{prefix}.{index}.{value_suffix}".format(
prefix=prefix, index=tag_index, value_suffix=value_suffix,
)
if tag_value_key in querystring_dict:
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
else:
response_values[tag_key] = None
return response_values
def tags_from_cloudformation_tags_list(tags_list):
"""Return tags in dict form from cloudformation resource tags form (list of dicts)"""
tags = {}
for entry in tags_list:
key = entry["Key"]
value = entry["Value"]
tags[key] = value
return tags
def remap_nested_keys(root, key_transform):
"""This remap ("recursive map") function is used to traverse and
transform the dictionary keys of arbitrarily nested structures.
List comprehensions do not recurse, making it tedious to apply
transforms to all keys in a tree-like structure.
A common issue for `moto` is changing the casing of dict keys:
>>> remap_nested_keys({'KeyName': 'Value'}, camelcase_to_underscores)
{'key_name': 'Value'}
Args:
root: The target data to traverse. Supports iterables like
:class:`list`, :class:`tuple`, and :class:`dict`.
key_transform (callable): This function is called on every
dictionary key found in *root*.
"""
if isinstance(root, (list, tuple)):
return [remap_nested_keys(item, key_transform) for item in root]
if isinstance(root, dict):
return {
key_transform(k): remap_nested_keys(v, key_transform)
for k, v in six.iteritems(root)
}
return root

View File

@ -4,7 +4,7 @@ import datetime
from boto3 import Session
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys
@ -18,7 +18,7 @@ class PipelineObject(BaseModel):
return {"fields": self.fields, "id": self.object_id, "name": self.name}
class Pipeline(BaseModel):
class Pipeline(CloudFormationModel):
def __init__(self, name, unique_id, **kwargs):
self.name = name
self.unique_id = unique_id
@ -74,6 +74,15 @@ class Pipeline(BaseModel):
def activate(self):
self.status = "SCHEDULED"
@staticmethod
def cloudformation_name_type():
return "Name"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-datapipeline-pipeline.html
return "AWS::DataPipeline::Pipeline"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
@ -81,9 +90,9 @@ class Pipeline(BaseModel):
datapipeline_backend = datapipeline_backends[region_name]
properties = cloudformation_json["Properties"]
cloudformation_unique_id = "cf-" + properties["Name"]
cloudformation_unique_id = "cf-" + resource_name
pipeline = datapipeline_backend.create_pipeline(
properties["Name"], cloudformation_unique_id
resource_name, cloudformation_unique_id
)
datapipeline_backend.put_pipeline_definition(
pipeline.pipeline_id, properties["PipelineObjects"]

View File

@ -4,7 +4,7 @@ import datetime
import json
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.utils import unix_time
from moto.core import ACCOUNT_ID
from .comparisons import get_comparison_func
@ -82,7 +82,7 @@ class Item(BaseModel):
return {"Item": included}
class Table(BaseModel):
class Table(CloudFormationModel):
def __init__(
self,
name,
@ -135,6 +135,15 @@ class Table(BaseModel):
}
return results
@staticmethod
def cloudformation_name_type():
return "TableName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html
return "AWS::DynamoDB::Table"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name

View File

@ -1,5 +1,5 @@
from __future__ import unicode_literals
from .models import dynamodb_backends as dynamodb_backends2
from moto.dynamodb2.models import dynamodb_backends as dynamodb_backends2
from ..core.models import base_decorator, deprecated_base_decorator
dynamodb_backend2 = dynamodb_backends2["us-east-1"]

View File

@ -251,9 +251,9 @@ class ConditionExpressionParser:
def _lex_one_node(self, remaining_expression):
# TODO: Handle indexing like [1]
attribute_regex = "(:|#)?[A-z0-9\-_]+"
attribute_regex = r"(:|#)?[A-z0-9\-_]+"
patterns = [
(self.Nonterminal.WHITESPACE, re.compile("^ +")),
(self.Nonterminal.WHITESPACE, re.compile(r"^ +")),
(
self.Nonterminal.COMPARATOR,
re.compile(
@ -270,12 +270,14 @@ class ConditionExpressionParser:
(
self.Nonterminal.OPERAND,
re.compile(
"^" + attribute_regex + "(\." + attribute_regex + "|\[[0-9]\])*"
r"^{attribute_regex}(\.{attribute_regex}|\[[0-9]\])*".format(
attribute_regex=attribute_regex
)
),
),
(self.Nonterminal.COMMA, re.compile("^,")),
(self.Nonterminal.LEFT_PAREN, re.compile("^\(")),
(self.Nonterminal.RIGHT_PAREN, re.compile("^\)")),
(self.Nonterminal.COMMA, re.compile(r"^,")),
(self.Nonterminal.LEFT_PAREN, re.compile(r"^\(")),
(self.Nonterminal.RIGHT_PAREN, re.compile(r"^\)")),
]
for nonterminal, pattern in patterns:
@ -285,7 +287,7 @@ class ConditionExpressionParser:
break
else: # pragma: no cover
raise ValueError(
"Cannot parse condition starting at: " + remaining_expression
"Cannot parse condition starting at:{}".format(remaining_expression)
)
node = self.Node(
@ -318,7 +320,7 @@ class ConditionExpressionParser:
for child in children:
self._assert(
child.nonterminal == self.Nonterminal.IDENTIFIER,
"Cannot use %s in path" % child.text,
"Cannot use {} in path".format(child.text),
[node],
)
output.append(
@ -392,7 +394,7 @@ class ConditionExpressionParser:
elif name.startswith("["):
# e.g. [123]
if not name.endswith("]"): # pragma: no cover
raise ValueError("Bad path element %s" % name)
raise ValueError("Bad path element {}".format(name))
return self.Node(
nonterminal=self.Nonterminal.IDENTIFIER,
kind=self.Kind.LITERAL,

View File

@ -2,9 +2,172 @@ class InvalidIndexNameError(ValueError):
pass
class InvalidUpdateExpression(ValueError):
pass
class MockValidationException(ValueError):
def __init__(self, message):
self.exception_msg = message
class ItemSizeTooLarge(Exception):
message = "Item size has exceeded the maximum allowed size"
class InvalidUpdateExpressionInvalidDocumentPath(MockValidationException):
invalid_update_expression_msg = (
"The document path provided in the update expression is invalid for update"
)
def __init__(self):
super(InvalidUpdateExpressionInvalidDocumentPath, self).__init__(
self.invalid_update_expression_msg
)
class InvalidUpdateExpression(MockValidationException):
invalid_update_expr_msg = "Invalid UpdateExpression: {update_expression_error}"
def __init__(self, update_expression_error):
self.update_expression_error = update_expression_error
super(InvalidUpdateExpression, self).__init__(
self.invalid_update_expr_msg.format(
update_expression_error=update_expression_error
)
)
class AttributeDoesNotExist(MockValidationException):
attr_does_not_exist_msg = (
"The provided expression refers to an attribute that does not exist in the item"
)
def __init__(self):
super(AttributeDoesNotExist, self).__init__(self.attr_does_not_exist_msg)
class ProvidedKeyDoesNotExist(MockValidationException):
provided_key_does_not_exist_msg = (
"The provided key element does not match the schema"
)
def __init__(self):
super(ProvidedKeyDoesNotExist, self).__init__(
self.provided_key_does_not_exist_msg
)
class ExpressionAttributeNameNotDefined(InvalidUpdateExpression):
name_not_defined_msg = "An expression attribute name used in the document path is not defined; attribute name: {n}"
def __init__(self, attribute_name):
self.not_defined_attribute_name = attribute_name
super(ExpressionAttributeNameNotDefined, self).__init__(
self.name_not_defined_msg.format(n=attribute_name)
)
class AttributeIsReservedKeyword(InvalidUpdateExpression):
attribute_is_keyword_msg = (
"Attribute name is a reserved keyword; reserved keyword: {keyword}"
)
def __init__(self, keyword):
self.keyword = keyword
super(AttributeIsReservedKeyword, self).__init__(
self.attribute_is_keyword_msg.format(keyword=keyword)
)
class ExpressionAttributeValueNotDefined(InvalidUpdateExpression):
attr_value_not_defined_msg = "An expression attribute value used in expression is not defined; attribute value: {attribute_value}"
def __init__(self, attribute_value):
self.attribute_value = attribute_value
super(ExpressionAttributeValueNotDefined, self).__init__(
self.attr_value_not_defined_msg.format(attribute_value=attribute_value)
)
class UpdateExprSyntaxError(InvalidUpdateExpression):
update_expr_syntax_error_msg = "Syntax error; {error_detail}"
def __init__(self, error_detail):
self.error_detail = error_detail
super(UpdateExprSyntaxError, self).__init__(
self.update_expr_syntax_error_msg.format(error_detail=error_detail)
)
class InvalidTokenException(UpdateExprSyntaxError):
token_detail_msg = 'token: "{token}", near: "{near}"'
def __init__(self, token, near):
self.token = token
self.near = near
super(InvalidTokenException, self).__init__(
self.token_detail_msg.format(token=token, near=near)
)
class InvalidExpressionAttributeNameKey(MockValidationException):
invalid_expr_attr_name_msg = (
'ExpressionAttributeNames contains invalid key: Syntax error; key: "{key}"'
)
def __init__(self, key):
self.key = key
super(InvalidExpressionAttributeNameKey, self).__init__(
self.invalid_expr_attr_name_msg.format(key=key)
)
class ItemSizeTooLarge(MockValidationException):
item_size_too_large_msg = "Item size has exceeded the maximum allowed size"
def __init__(self):
super(ItemSizeTooLarge, self).__init__(self.item_size_too_large_msg)
class ItemSizeToUpdateTooLarge(MockValidationException):
item_size_to_update_too_large_msg = (
"Item size to update has exceeded the maximum allowed size"
)
def __init__(self):
super(ItemSizeToUpdateTooLarge, self).__init__(
self.item_size_to_update_too_large_msg
)
class IncorrectOperandType(InvalidUpdateExpression):
inv_operand_msg = "Incorrect operand type for operator or function; operator or function: {f}, operand type: {t}"
def __init__(self, operator_or_function, operand_type):
self.operator_or_function = operator_or_function
self.operand_type = operand_type
super(IncorrectOperandType, self).__init__(
self.inv_operand_msg.format(f=operator_or_function, t=operand_type)
)
class IncorrectDataType(MockValidationException):
inc_data_type_msg = "An operand in the update expression has an incorrect data type"
def __init__(self):
super(IncorrectDataType, self).__init__(self.inc_data_type_msg)
class ConditionalCheckFailed(ValueError):
msg = "The conditional request failed"
def __init__(self):
super(ConditionalCheckFailed, self).__init__(self.msg)
class TransactionCanceledException(ValueError):
cancel_reason_msg = "Transaction cancelled, please refer cancellation reasons for specific reasons [{}]"
def __init__(self, errors):
msg = self.cancel_reason_msg.format(", ".join([str(err) for err in errors]))
super(TransactionCanceledException, self).__init__(msg)
class EmptyKeyAttributeException(MockValidationException):
empty_str_msg = "One or more parameter values were invalid: An AttributeValue may not contain an empty string"
def __init__(self):
super(EmptyKeyAttributeException, self).__init__(self.empty_str_msg)

View File

@ -0,0 +1,317 @@
import six
from moto.dynamodb2.comparisons import get_comparison_func
from moto.dynamodb2.exceptions import InvalidUpdateExpression, IncorrectDataType
from moto.dynamodb2.models.utilities import attribute_is_list, bytesize
class DDBType(object):
"""
Official documentation at https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_AttributeValue.html
"""
BINARY_SET = "BS"
NUMBER_SET = "NS"
STRING_SET = "SS"
STRING = "S"
NUMBER = "N"
MAP = "M"
LIST = "L"
BOOLEAN = "BOOL"
BINARY = "B"
NULL = "NULL"
class DDBTypeConversion(object):
_human_type_mapping = {
val: key.replace("_", " ")
for key, val in DDBType.__dict__.items()
if key.upper() == key
}
@classmethod
def get_human_type(cls, abbreviated_type):
"""
Args:
abbreviated_type(str): An attribute of DDBType
Returns:
str: The human readable form of the DDBType.
"""
try:
human_type_str = cls._human_type_mapping[abbreviated_type]
except KeyError:
raise ValueError(
"Invalid abbreviated_type {at}".format(at=abbreviated_type)
)
return human_type_str
class DynamoType(object):
"""
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
"""
def __init__(self, type_as_dict):
if type(type_as_dict) == DynamoType:
self.type = type_as_dict.type
self.value = type_as_dict.value
else:
self.type = list(type_as_dict)[0]
self.value = list(type_as_dict.values())[0]
if self.is_list():
self.value = [DynamoType(val) for val in self.value]
elif self.is_map():
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
def get(self, key):
if not key:
return self
else:
key_head = key.split(".")[0]
key_tail = ".".join(key.split(".")[1:])
if key_head not in self.value:
self.value[key_head] = DynamoType({"NONE": None})
return self.value[key_head].get(key_tail)
def set(self, key, new_value, index=None):
if index:
index = int(index)
if type(self.value) is not list:
raise InvalidUpdateExpression
if index >= len(self.value):
self.value.append(new_value)
# {'L': [DynamoType, ..]} ==> DynamoType.set()
self.value[min(index, len(self.value) - 1)].set(key, new_value)
else:
attr = (key or "").split(".").pop(0)
attr, list_index = attribute_is_list(attr)
if not key:
# {'S': value} ==> {'S': new_value}
self.type = new_value.type
self.value = new_value.value
else:
if attr not in self.value: # nonexistingattribute
type_of_new_attr = DDBType.MAP if "." in key else new_value.type
self.value[attr] = DynamoType({type_of_new_attr: {}})
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
self.value[attr].set(
".".join(key.split(".")[1:]), new_value, list_index
)
def __contains__(self, item):
if self.type == DDBType.STRING:
return False
try:
self.__getitem__(item)
return True
except KeyError:
return False
def delete(self, key, index=None):
if index:
if not key:
if int(index) < len(self.value):
del self.value[int(index)]
elif "." in key:
self.value[int(index)].delete(".".join(key.split(".")[1:]))
else:
self.value[int(index)].delete(key)
else:
attr = key.split(".")[0]
attr, list_index = attribute_is_list(attr)
if list_index:
self.value[attr].delete(".".join(key.split(".")[1:]), list_index)
elif "." in key:
self.value[attr].delete(".".join(key.split(".")[1:]))
else:
self.value.pop(key)
def filter(self, projection_expressions):
nested_projections = [
expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr
]
if self.is_map():
expressions_to_delete = []
for attr in self.value:
if (
attr not in projection_expressions
and attr not in nested_projections
):
expressions_to_delete.append(attr)
elif attr in nested_projections:
relevant_expressions = [
expr[len(attr + ".") :]
for expr in projection_expressions
if expr.startswith(attr + ".")
]
self.value[attr].filter(relevant_expressions)
for expr in expressions_to_delete:
self.value.pop(expr)
def __hash__(self):
return hash((self.type, self.value))
def __eq__(self, other):
return self.type == other.type and self.value == other.value
def __ne__(self, other):
return self.type != other.type or self.value != other.value
def __lt__(self, other):
return self.cast_value < other.cast_value
def __le__(self, other):
return self.cast_value <= other.cast_value
def __gt__(self, other):
return self.cast_value > other.cast_value
def __ge__(self, other):
return self.cast_value >= other.cast_value
def __repr__(self):
return "DynamoType: {0}".format(self.to_json())
def __add__(self, other):
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.is_number():
self_value = float(self.value) if "." in self.value else int(self.value)
other_value = float(other.value) if "." in other.value else int(other.value)
return DynamoType(
{DDBType.NUMBER: "{v}".format(v=self_value + other_value)}
)
else:
raise IncorrectDataType()
def __sub__(self, other):
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.type == DDBType.NUMBER:
self_value = float(self.value) if "." in self.value else int(self.value)
other_value = float(other.value) if "." in other.value else int(other.value)
return DynamoType(
{DDBType.NUMBER: "{v}".format(v=self_value - other_value)}
)
else:
raise TypeError("Sum only supported for Numbers.")
def __getitem__(self, item):
if isinstance(item, six.string_types):
# If our DynamoType is a map it should be subscriptable with a key
if self.type == DDBType.MAP:
return self.value[item]
elif isinstance(item, int):
# If our DynamoType is a list is should be subscriptable with an index
if self.type == DDBType.LIST:
return self.value[item]
raise TypeError(
"This DynamoType {dt} is not subscriptable by a {it}".format(
dt=self.type, it=type(item)
)
)
def __setitem__(self, key, value):
if isinstance(key, int):
if self.is_list():
if key >= len(self.value):
# DynamoDB doesn't care you are out of box just add it to the end.
self.value.append(value)
else:
self.value[key] = value
elif isinstance(key, six.string_types):
if self.is_map():
self.value[key] = value
else:
raise NotImplementedError("No set_item for {t}".format(t=type(key)))
@property
def cast_value(self):
if self.is_number():
try:
return int(self.value)
except ValueError:
return float(self.value)
elif self.is_set():
sub_type = self.type[0]
return set([DynamoType({sub_type: v}).cast_value for v in self.value])
elif self.is_list():
return [DynamoType(v).cast_value for v in self.value]
elif self.is_map():
return dict([(k, DynamoType(v).cast_value) for k, v in self.value.items()])
else:
return self.value
def child_attr(self, key):
"""
Get Map or List children by key. str for Map, int for List.
Returns DynamoType or None.
"""
if isinstance(key, six.string_types) and self.is_map():
if "." in key and key.split(".")[0] in self.value:
return self.value[key.split(".")[0]].child_attr(
".".join(key.split(".")[1:])
)
elif "." not in key and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, int) and self.is_list():
idx = key
if 0 <= idx < len(self.value):
return DynamoType(self.value[idx])
return None
def size(self):
if self.is_number():
value_size = len(str(self.value))
elif self.is_set():
sub_type = self.type[0]
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
elif self.is_list():
value_size = sum([v.size() for v in self.value])
elif self.is_map():
value_size = sum(
[bytesize(k) + DynamoType(v).size() for k, v in self.value.items()]
)
elif type(self.value) == bool:
value_size = 1
else:
value_size = bytesize(self.value)
return value_size
def to_json(self):
return {self.type: self.value}
def compare(self, range_comparison, range_objs):
"""
Compares this type against comparison filters
"""
range_values = [obj.cast_value for obj in range_objs]
comparison_func = get_comparison_func(range_comparison)
return comparison_func(self.cast_value, *range_values)
def is_number(self):
return self.type == DDBType.NUMBER
def is_set(self):
return self.type in (DDBType.STRING_SET, DDBType.NUMBER_SET, DDBType.BINARY_SET)
def is_list(self):
return self.type == DDBType.LIST
def is_map(self):
return self.type == DDBType.MAP
def same_type(self, other):
return self.type == other.type
def pop(self, key, *args, **kwargs):
if self.is_map() or self.is_list():
self.value.pop(key, *args, **kwargs)
else:
raise TypeError("pop not supported for DynamoType {t}".format(t=self.type))

View File

@ -0,0 +1,17 @@
import re
def bytesize(val):
return len(str(val).encode("utf-8"))
def attribute_is_list(attr):
"""
Checks if attribute denotes a list, and returns the name of the list and the given list index if so
:param attr: attr or attr[index]
:return: attr, index or None
"""
list_index_update = re.match("(.+)\\[([0-9]+)\\]", attr)
if list_index_update:
attr = list_index_update.group(1)
return attr, list_index_update.group(2) if list_index_update else None

View File

@ -0,0 +1,23 @@
# Parsing dev documentation
Parsing happens in a structured manner and happens in different phases.
This document explains these phases.
## 1) Expression gets parsed into a tokenlist (tokenized)
A string gets parsed from left to right and gets converted into a list of tokens.
The tokens are available in `tokens.py`.
## 2) Tokenlist get transformed to expression tree (AST)
This is the parsing of the token list. This parsing will result in an Abstract Syntax Tree (AST).
The different node types are available in `ast_nodes.py`. The AST is a representation that has all
the information that is in the expression but its tree form allows processing it in a structured manner.
## 3) The AST gets validated (full semantic correctness)
The AST is used for validation. The paths and attributes are validated to be correct. At the end of the
validation all the values will be resolved.
## 4) Update Expression gets executed using the validated AST
Finally the AST is used to execute the update expression. There should be no reason for this step to fail
since validation has completed. Due to this we have the update expressions behaving atomically (i.e. all the
actions of the update expresion are performed or none of them are performed).

View File

View File

@ -0,0 +1,360 @@
import abc
from abc import abstractmethod
from collections import deque
import six
from moto.dynamodb2.models import DynamoType
@six.add_metaclass(abc.ABCMeta)
class Node:
def __init__(self, children=None):
self.type = self.__class__.__name__
assert children is None or isinstance(children, list)
self.children = children
self.parent = None
if isinstance(children, list):
for child in children:
if isinstance(child, Node):
child.set_parent(self)
def set_parent(self, parent_node):
self.parent = parent_node
class LeafNode(Node):
"""A LeafNode is a Node where none of the children are Nodes themselves."""
def __init__(self, children=None):
super(LeafNode, self).__init__(children)
@six.add_metaclass(abc.ABCMeta)
class Expression(Node):
"""
Abstract Syntax Tree representing the expression
For the Grammar start here and jump down into the classes at the righ-hand side to look further. Nodes marked with
a star are abstract and won't appear in the final AST.
Expression* => UpdateExpression
Expression* => ConditionExpression
"""
class UpdateExpression(Expression):
"""
UpdateExpression => UpdateExpressionClause*
UpdateExpression => UpdateExpressionClause* UpdateExpression
"""
@six.add_metaclass(abc.ABCMeta)
class UpdateExpressionClause(UpdateExpression):
"""
UpdateExpressionClause* => UpdateExpressionSetClause
UpdateExpressionClause* => UpdateExpressionRemoveClause
UpdateExpressionClause* => UpdateExpressionAddClause
UpdateExpressionClause* => UpdateExpressionDeleteClause
"""
class UpdateExpressionSetClause(UpdateExpressionClause):
"""
UpdateExpressionSetClause => SET SetActions
"""
class UpdateExpressionSetActions(UpdateExpressionClause):
"""
UpdateExpressionSetClause => SET SetActions
SetActions => SetAction
SetActions => SetAction , SetActions
"""
class UpdateExpressionSetAction(UpdateExpressionClause):
"""
SetAction => Path = Value
"""
class UpdateExpressionRemoveActions(UpdateExpressionClause):
"""
UpdateExpressionSetClause => REMOVE RemoveActions
RemoveActions => RemoveAction
RemoveActions => RemoveAction , RemoveActions
"""
class UpdateExpressionRemoveAction(UpdateExpressionClause):
"""
RemoveAction => Path
"""
class UpdateExpressionAddActions(UpdateExpressionClause):
"""
UpdateExpressionAddClause => ADD RemoveActions
AddActions => AddAction
AddActions => AddAction , AddActions
"""
class UpdateExpressionAddAction(UpdateExpressionClause):
"""
AddAction => Path Value
"""
class UpdateExpressionDeleteActions(UpdateExpressionClause):
"""
UpdateExpressionDeleteClause => DELETE RemoveActions
DeleteActions => DeleteAction
DeleteActions => DeleteAction , DeleteActions
"""
class UpdateExpressionDeleteAction(UpdateExpressionClause):
"""
DeleteAction => Path Value
"""
class UpdateExpressionPath(UpdateExpressionClause):
pass
class UpdateExpressionValue(UpdateExpressionClause):
"""
Value => Operand
Value => Operand + Value
Value => Operand - Value
"""
class UpdateExpressionGroupedValue(UpdateExpressionClause):
"""
GroupedValue => ( Value )
"""
class UpdateExpressionRemoveClause(UpdateExpressionClause):
"""
UpdateExpressionRemoveClause => REMOVE RemoveActions
"""
class UpdateExpressionAddClause(UpdateExpressionClause):
"""
UpdateExpressionAddClause => ADD AddActions
"""
class UpdateExpressionDeleteClause(UpdateExpressionClause):
"""
UpdateExpressionDeleteClause => DELETE DeleteActions
"""
class ExpressionPathDescender(Node):
"""Node identifying descender into nested structure (.) in expression"""
class ExpressionSelector(LeafNode):
"""Node identifying selector [selection_index] in expresion"""
def __init__(self, selection_index):
try:
super(ExpressionSelector, self).__init__(children=[int(selection_index)])
except ValueError:
assert (
False
), "Expression selector must be an int, this is a bug in the moto library."
def get_index(self):
return self.children[0]
class ExpressionAttribute(LeafNode):
"""An attribute identifier as used in the DDB item"""
def __init__(self, attribute):
super(ExpressionAttribute, self).__init__(children=[attribute])
def get_attribute_name(self):
return self.children[0]
class ExpressionAttributeName(LeafNode):
"""An ExpressionAttributeName is an alias for an attribute identifier"""
def __init__(self, attribute_name):
super(ExpressionAttributeName, self).__init__(children=[attribute_name])
def get_attribute_name_placeholder(self):
return self.children[0]
class ExpressionAttributeValue(LeafNode):
"""An ExpressionAttributeValue is an alias for an value"""
def __init__(self, value):
super(ExpressionAttributeValue, self).__init__(children=[value])
def get_value_name(self):
return self.children[0]
class ExpressionValueOperator(LeafNode):
"""An ExpressionValueOperator is an operation that works on 2 values"""
def __init__(self, value):
super(ExpressionValueOperator, self).__init__(children=[value])
def get_operator(self):
return self.children[0]
class UpdateExpressionFunction(Node):
"""
A Node representing a function of an Update Expression. The first child is the function name the others are the
arguments.
"""
def get_function_name(self):
return self.children[0]
def get_nth_argument(self, n=1):
"""Return nth element where n is a 1-based index."""
assert n >= 1
return self.children[n]
class DDBTypedValue(Node):
"""
A node representing a DDBTyped value. This can be any structure as supported by DyanmoDB. The node only has 1 child
which is the value of type `DynamoType`.
"""
def __init__(self, value):
assert isinstance(value, DynamoType), "DDBTypedValue must be of DynamoType"
super(DDBTypedValue, self).__init__(children=[value])
def get_value(self):
return self.children[0]
class NoneExistingPath(LeafNode):
"""A placeholder for Paths that did not exist in the Item."""
def __init__(self, creatable=False):
super(NoneExistingPath, self).__init__(children=[creatable])
def is_creatable(self):
"""Can this path be created if need be. For example path creating element in a dictionary or creating a new
attribute under root level of an item."""
return self.children[0]
class DepthFirstTraverser(object):
"""
Helper class that allows depth first traversal and to implement custom processing for certain AST nodes. The
processor of a node must return the new resulting node. This node will be placed in the tree. Processing of a
node using this traverser should therefore only transform child nodes. The returned node will get the same parent
as the node before processing had.
"""
@abstractmethod
def _processing_map(self):
"""
A map providing a processing function per node class type to a function that takes in a Node object and
processes it. A Node can only be processed by a single function and they are considered in order. Therefore if
multiple classes from a single class hierarchy strain are used the more specific classes have to be put before
the less specific ones. That requires overriding `nodes_to_be_processed`. If no multiple classes form a single
class hierarchy strain are used the default implementation of `nodes_to_be_processed` should be OK.
Returns:
dict: Mapping a Node Class to a processing function.
"""
pass
def nodes_to_be_processed(self):
"""Cached accessor for getting Node types that need to be processed."""
return tuple(k for k in self._processing_map().keys())
def process(self, node):
"""Process a Node"""
for class_key, processor in self._processing_map().items():
if isinstance(node, class_key):
return processor(node)
def pre_processing_of_child(self, parent_node, child_id):
"""Hook that is called pre-processing of the child at position `child_id`"""
pass
def traverse_node_recursively(self, node, child_id=-1):
"""
Traverse nodes depth first processing nodes bottom up (if root node is considered the top).
Args:
node(Node): The node which is the last node to be processed but which allows to identify all the
work (which is in the children)
child_id(int): The index in the list of children from the parent that this node corresponds to
Returns:
Node: The node of the new processed AST
"""
if isinstance(node, Node):
parent_node = node.parent
if node.children is not None:
for i, child_node in enumerate(node.children):
self.pre_processing_of_child(node, i)
self.traverse_node_recursively(child_node, i)
# noinspection PyTypeChecker
if isinstance(node, self.nodes_to_be_processed()):
node = self.process(node)
node.parent = parent_node
parent_node.children[child_id] = node
return node
def traverse(self, node):
return self.traverse_node_recursively(node)
class NodeDepthLeftTypeFetcher(object):
"""Helper class to fetch a node of a specific type. Depth left-first traversal"""
def __init__(self, node_type, root_node):
assert issubclass(node_type, Node)
self.node_type = node_type
self.root_node = root_node
self.queue = deque()
self.add_nodes_left_to_right_depth_first(self.root_node)
def add_nodes_left_to_right_depth_first(self, node):
if isinstance(node, Node) and node.children is not None:
for child_node in node.children:
self.add_nodes_left_to_right_depth_first(child_node)
self.queue.append(child_node)
self.queue.append(node)
def __iter__(self):
return self
def next(self):
return self.__next__()
def __next__(self):
while len(self.queue) > 0:
candidate = self.queue.popleft()
if isinstance(candidate, self.node_type):
return candidate
else:
raise StopIteration

View File

@ -0,0 +1,288 @@
from abc import abstractmethod
from moto.dynamodb2.exceptions import (
IncorrectOperandType,
IncorrectDataType,
ProvidedKeyDoesNotExist,
)
from moto.dynamodb2.models import DynamoType
from moto.dynamodb2.models.dynamo_type import DDBTypeConversion, DDBType
from moto.dynamodb2.parsing.ast_nodes import (
UpdateExpressionSetAction,
UpdateExpressionDeleteAction,
UpdateExpressionRemoveAction,
UpdateExpressionAddAction,
UpdateExpressionPath,
DDBTypedValue,
ExpressionAttribute,
ExpressionSelector,
ExpressionAttributeName,
)
from moto.dynamodb2.parsing.validators import ExpressionPathResolver
class NodeExecutor(object):
def __init__(self, ast_node, expression_attribute_names):
self.node = ast_node
self.expression_attribute_names = expression_attribute_names
@abstractmethod
def execute(self, item):
pass
def get_item_part_for_path_nodes(self, item, path_nodes):
"""
For a list of path nodes travers the item by following the path_nodes
Args:
item(Item):
path_nodes(list):
Returns:
"""
if len(path_nodes) == 0:
return item.attrs
else:
return ExpressionPathResolver(
self.expression_attribute_names
).resolve_expression_path_nodes_to_dynamo_type(item, path_nodes)
def get_item_before_end_of_path(self, item):
"""
Get the part ot the item where the item will perform the action. For most actions this should be the parent. As
that element will need to be modified by the action.
Args:
item(Item):
Returns:
DynamoType or dict: The path to be set
"""
return self.get_item_part_for_path_nodes(
item, self.get_path_expression_nodes()[:-1]
)
def get_item_at_end_of_path(self, item):
"""
For a DELETE the path points at the stringset so we need to evaluate the full path.
Args:
item(Item):
Returns:
DynamoType or dict: The path to be set
"""
return self.get_item_part_for_path_nodes(item, self.get_path_expression_nodes())
# Get the part ot the item where the item will perform the action. For most actions this should be the parent. As
# that element will need to be modified by the action.
get_item_part_in_which_to_perform_action = get_item_before_end_of_path
def get_path_expression_nodes(self):
update_expression_path = self.node.children[0]
assert isinstance(update_expression_path, UpdateExpressionPath)
return update_expression_path.children
def get_element_to_action(self):
return self.get_path_expression_nodes()[-1]
def get_action_value(self):
"""
Returns:
DynamoType: The value to be set
"""
ddb_typed_value = self.node.children[1]
assert isinstance(ddb_typed_value, DDBTypedValue)
dynamo_type_value = ddb_typed_value.children[0]
assert isinstance(dynamo_type_value, DynamoType)
return dynamo_type_value
class SetExecutor(NodeExecutor):
def execute(self, item):
self.set(
item_part_to_modify_with_set=self.get_item_part_in_which_to_perform_action(
item
),
element_to_set=self.get_element_to_action(),
value_to_set=self.get_action_value(),
expression_attribute_names=self.expression_attribute_names,
)
@classmethod
def set(
cls,
item_part_to_modify_with_set,
element_to_set,
value_to_set,
expression_attribute_names,
):
if isinstance(element_to_set, ExpressionAttribute):
attribute_name = element_to_set.get_attribute_name()
item_part_to_modify_with_set[attribute_name] = value_to_set
elif isinstance(element_to_set, ExpressionSelector):
index = element_to_set.get_index()
item_part_to_modify_with_set[index] = value_to_set
elif isinstance(element_to_set, ExpressionAttributeName):
attribute_name = expression_attribute_names[
element_to_set.get_attribute_name_placeholder()
]
item_part_to_modify_with_set[attribute_name] = value_to_set
else:
raise NotImplementedError(
"Moto does not support setting {t} yet".format(t=type(element_to_set))
)
class DeleteExecutor(NodeExecutor):
operator = "operator: DELETE"
def execute(self, item):
string_set_to_remove = self.get_action_value()
assert isinstance(string_set_to_remove, DynamoType)
if not string_set_to_remove.is_set():
raise IncorrectOperandType(
self.operator,
DDBTypeConversion.get_human_type(string_set_to_remove.type),
)
string_set = self.get_item_at_end_of_path(item)
assert isinstance(string_set, DynamoType)
if string_set.type != string_set_to_remove.type:
raise IncorrectDataType()
# String set is currently implemented as a list
string_set_list = string_set.value
stringset_to_remove_list = string_set_to_remove.value
for value in stringset_to_remove_list:
try:
string_set_list.remove(value)
except (KeyError, ValueError):
# DynamoDB does not mind if value is not present
pass
# DynamoDB does not support empty sets. If we've deleted
# the last item in the set, we have to remove the attribute.
if not string_set_list:
element = self.get_element_to_action()
container = self.get_item_before_end_of_path(item)
container.pop(element.get_attribute_name())
class RemoveExecutor(NodeExecutor):
def execute(self, item):
element_to_remove = self.get_element_to_action()
if isinstance(element_to_remove, ExpressionAttribute):
attribute_name = element_to_remove.get_attribute_name()
self.get_item_part_in_which_to_perform_action(item).pop(
attribute_name, None
)
elif isinstance(element_to_remove, ExpressionAttributeName):
attribute_name = self.expression_attribute_names[
element_to_remove.get_attribute_name_placeholder()
]
self.get_item_part_in_which_to_perform_action(item).pop(
attribute_name, None
)
elif isinstance(element_to_remove, ExpressionSelector):
index = element_to_remove.get_index()
try:
self.get_item_part_in_which_to_perform_action(item).pop(index)
except IndexError:
# DynamoDB does not care that index is out of bounds, it will just do nothing.
pass
else:
raise NotImplementedError(
"Moto does not support setting {t} yet".format(
t=type(element_to_remove)
)
)
class AddExecutor(NodeExecutor):
def execute(self, item):
value_to_add = self.get_action_value()
if isinstance(value_to_add, DynamoType):
if value_to_add.is_set():
try:
current_string_set = self.get_item_at_end_of_path(item)
except ProvidedKeyDoesNotExist:
current_string_set = DynamoType({value_to_add.type: []})
SetExecutor.set(
item_part_to_modify_with_set=self.get_item_before_end_of_path(
item
),
element_to_set=self.get_element_to_action(),
value_to_set=current_string_set,
expression_attribute_names=self.expression_attribute_names,
)
assert isinstance(current_string_set, DynamoType)
if not current_string_set.type == value_to_add.type:
raise IncorrectDataType()
# Sets are implemented as list
for value in value_to_add.value:
if value in current_string_set.value:
continue
else:
current_string_set.value.append(value)
elif value_to_add.type == DDBType.NUMBER:
try:
existing_value = self.get_item_at_end_of_path(item)
except ProvidedKeyDoesNotExist:
existing_value = DynamoType({DDBType.NUMBER: "0"})
assert isinstance(existing_value, DynamoType)
if not existing_value.type == DDBType.NUMBER:
raise IncorrectDataType()
new_value = existing_value + value_to_add
SetExecutor.set(
item_part_to_modify_with_set=self.get_item_before_end_of_path(item),
element_to_set=self.get_element_to_action(),
value_to_set=new_value,
expression_attribute_names=self.expression_attribute_names,
)
else:
raise IncorrectDataType()
class UpdateExpressionExecutor(object):
execution_map = {
UpdateExpressionSetAction: SetExecutor,
UpdateExpressionAddAction: AddExecutor,
UpdateExpressionRemoveAction: RemoveExecutor,
UpdateExpressionDeleteAction: DeleteExecutor,
}
def __init__(self, update_ast, item, expression_attribute_names):
self.update_ast = update_ast
self.item = item
self.expression_attribute_names = expression_attribute_names
def execute(self, node=None):
"""
As explained in moto.dynamodb2.parsing.expressions.NestableExpressionParserMixin._create_node the order of nodes
in the AST can be translated of the order of statements in the expression. As such we can start at the root node
and process the nodes 1-by-1. If no specific execution for the node type is defined we can execute the children
in order since it will be a container node that is expandable and left child will be first in the statement.
Args:
node(Node):
Returns:
None
"""
if node is None:
node = self.update_ast
node_executor = self.get_specific_execution(node)
if node_executor is None:
for node in node.children:
self.execute(node)
else:
node_executor(node, self.expression_attribute_names).execute(self.item)
def get_specific_execution(self, node):
for node_class in self.execution_map:
if isinstance(node, node_class):
return self.execution_map[node_class]
return None

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,29 @@
class ReservedKeywords(list):
"""
DynamoDB has an extensive list of keywords. Keywords are considered when validating the expression Tree.
Not earlier since an update expression like "SET path = VALUE 1" fails with:
'Invalid UpdateExpression: Syntax error; token: "1", near: "VALUE 1"'
"""
KEYWORDS = None
@classmethod
def get_reserved_keywords(cls):
if cls.KEYWORDS is None:
cls.KEYWORDS = cls._get_reserved_keywords()
return cls.KEYWORDS
@classmethod
def _get_reserved_keywords(cls):
"""
Get a list of reserved keywords of DynamoDB
"""
try:
import importlib.resources as pkg_resources
except ImportError:
import importlib_resources as pkg_resources
reserved_keywords = pkg_resources.read_text(
"moto.dynamodb2.parsing", "reserved_keywords.txt"
)
return reserved_keywords.split()

View File

@ -0,0 +1,573 @@
ABORT
ABSOLUTE
ACTION
ADD
AFTER
AGENT
AGGREGATE
ALL
ALLOCATE
ALTER
ANALYZE
AND
ANY
ARCHIVE
ARE
ARRAY
AS
ASC
ASCII
ASENSITIVE
ASSERTION
ASYMMETRIC
AT
ATOMIC
ATTACH
ATTRIBUTE
AUTH
AUTHORIZATION
AUTHORIZE
AUTO
AVG
BACK
BACKUP
BASE
BATCH
BEFORE
BEGIN
BETWEEN
BIGINT
BINARY
BIT
BLOB
BLOCK
BOOLEAN
BOTH
BREADTH
BUCKET
BULK
BY
BYTE
CALL
CALLED
CALLING
CAPACITY
CASCADE
CASCADED
CASE
CAST
CATALOG
CHAR
CHARACTER
CHECK
CLASS
CLOB
CLOSE
CLUSTER
CLUSTERED
CLUSTERING
CLUSTERS
COALESCE
COLLATE
COLLATION
COLLECTION
COLUMN
COLUMNS
COMBINE
COMMENT
COMMIT
COMPACT
COMPILE
COMPRESS
CONDITION
CONFLICT
CONNECT
CONNECTION
CONSISTENCY
CONSISTENT
CONSTRAINT
CONSTRAINTS
CONSTRUCTOR
CONSUMED
CONTINUE
CONVERT
COPY
CORRESPONDING
COUNT
COUNTER
CREATE
CROSS
CUBE
CURRENT
CURSOR
CYCLE
DATA
DATABASE
DATE
DATETIME
DAY
DEALLOCATE
DEC
DECIMAL
DECLARE
DEFAULT
DEFERRABLE
DEFERRED
DEFINE
DEFINED
DEFINITION
DELETE
DELIMITED
DEPTH
DEREF
DESC
DESCRIBE
DESCRIPTOR
DETACH
DETERMINISTIC
DIAGNOSTICS
DIRECTORIES
DISABLE
DISCONNECT
DISTINCT
DISTRIBUTE
DO
DOMAIN
DOUBLE
DROP
DUMP
DURATION
DYNAMIC
EACH
ELEMENT
ELSE
ELSEIF
EMPTY
ENABLE
END
EQUAL
EQUALS
ERROR
ESCAPE
ESCAPED
EVAL
EVALUATE
EXCEEDED
EXCEPT
EXCEPTION
EXCEPTIONS
EXCLUSIVE
EXEC
EXECUTE
EXISTS
EXIT
EXPLAIN
EXPLODE
EXPORT
EXPRESSION
EXTENDED
EXTERNAL
EXTRACT
FAIL
FALSE
FAMILY
FETCH
FIELDS
FILE
FILTER
FILTERING
FINAL
FINISH
FIRST
FIXED
FLATTERN
FLOAT
FOR
FORCE
FOREIGN
FORMAT
FORWARD
FOUND
FREE
FROM
FULL
FUNCTION
FUNCTIONS
GENERAL
GENERATE
GET
GLOB
GLOBAL
GO
GOTO
GRANT
GREATER
GROUP
GROUPING
HANDLER
HASH
HAVE
HAVING
HEAP
HIDDEN
HOLD
HOUR
IDENTIFIED
IDENTITY
IF
IGNORE
IMMEDIATE
IMPORT
IN
INCLUDING
INCLUSIVE
INCREMENT
INCREMENTAL
INDEX
INDEXED
INDEXES
INDICATOR
INFINITE
INITIALLY
INLINE
INNER
INNTER
INOUT
INPUT
INSENSITIVE
INSERT
INSTEAD
INT
INTEGER
INTERSECT
INTERVAL
INTO
INVALIDATE
IS
ISOLATION
ITEM
ITEMS
ITERATE
JOIN
KEY
KEYS
LAG
LANGUAGE
LARGE
LAST
LATERAL
LEAD
LEADING
LEAVE
LEFT
LENGTH
LESS
LEVEL
LIKE
LIMIT
LIMITED
LINES
LIST
LOAD
LOCAL
LOCALTIME
LOCALTIMESTAMP
LOCATION
LOCATOR
LOCK
LOCKS
LOG
LOGED
LONG
LOOP
LOWER
MAP
MATCH
MATERIALIZED
MAX
MAXLEN
MEMBER
MERGE
METHOD
METRICS
MIN
MINUS
MINUTE
MISSING
MOD
MODE
MODIFIES
MODIFY
MODULE
MONTH
MULTI
MULTISET
NAME
NAMES
NATIONAL
NATURAL
NCHAR
NCLOB
NEW
NEXT
NO
NONE
NOT
NULL
NULLIF
NUMBER
NUMERIC
OBJECT
OF
OFFLINE
OFFSET
OLD
ON
ONLINE
ONLY
OPAQUE
OPEN
OPERATOR
OPTION
OR
ORDER
ORDINALITY
OTHER
OTHERS
OUT
OUTER
OUTPUT
OVER
OVERLAPS
OVERRIDE
OWNER
PAD
PARALLEL
PARAMETER
PARAMETERS
PARTIAL
PARTITION
PARTITIONED
PARTITIONS
PATH
PERCENT
PERCENTILE
PERMISSION
PERMISSIONS
PIPE
PIPELINED
PLAN
POOL
POSITION
PRECISION
PREPARE
PRESERVE
PRIMARY
PRIOR
PRIVATE
PRIVILEGES
PROCEDURE
PROCESSED
PROJECT
PROJECTION
PROPERTY
PROVISIONING
PUBLIC
PUT
QUERY
QUIT
QUORUM
RAISE
RANDOM
RANGE
RANK
RAW
READ
READS
REAL
REBUILD
RECORD
RECURSIVE
REDUCE
REF
REFERENCE
REFERENCES
REFERENCING
REGEXP
REGION
REINDEX
RELATIVE
RELEASE
REMAINDER
RENAME
REPEAT
REPLACE
REQUEST
RESET
RESIGNAL
RESOURCE
RESPONSE
RESTORE
RESTRICT
RESULT
RETURN
RETURNING
RETURNS
REVERSE
REVOKE
RIGHT
ROLE
ROLES
ROLLBACK
ROLLUP
ROUTINE
ROW
ROWS
RULE
RULES
SAMPLE
SATISFIES
SAVE
SAVEPOINT
SCAN
SCHEMA
SCOPE
SCROLL
SEARCH
SECOND
SECTION
SEGMENT
SEGMENTS
SELECT
SELF
SEMI
SENSITIVE
SEPARATE
SEQUENCE
SERIALIZABLE
SESSION
SET
SETS
SHARD
SHARE
SHARED
SHORT
SHOW
SIGNAL
SIMILAR
SIZE
SKEWED
SMALLINT
SNAPSHOT
SOME
SOURCE
SPACE
SPACES
SPARSE
SPECIFIC
SPECIFICTYPE
SPLIT
SQL
SQLCODE
SQLERROR
SQLEXCEPTION
SQLSTATE
SQLWARNING
START
STATE
STATIC
STATUS
STORAGE
STORE
STORED
STREAM
STRING
STRUCT
STYLE
SUB
SUBMULTISET
SUBPARTITION
SUBSTRING
SUBTYPE
SUM
SUPER
SYMMETRIC
SYNONYM
SYSTEM
TABLE
TABLESAMPLE
TEMP
TEMPORARY
TERMINATED
TEXT
THAN
THEN
THROUGHPUT
TIME
TIMESTAMP
TIMEZONE
TINYINT
TO
TOKEN
TOTAL
TOUCH
TRAILING
TRANSACTION
TRANSFORM
TRANSLATE
TRANSLATION
TREAT
TRIGGER
TRIM
TRUE
TRUNCATE
TTL
TUPLE
TYPE
UNDER
UNDO
UNION
UNIQUE
UNIT
UNKNOWN
UNLOGGED
UNNEST
UNPROCESSED
UNSIGNED
UNTIL
UPDATE
UPPER
URL
USAGE
USE
USER
USERS
USING
UUID
VACUUM
VALUE
VALUED
VALUES
VARCHAR
VARIABLE
VARIANCE
VARINT
VARYING
VIEW
VIEWS
VIRTUAL
VOID
WAIT
WHEN
WHENEVER
WHERE
WHILE
WINDOW
WITH
WITHIN
WITHOUT
WORK
WRAPPED
WRITE
YEAR
ZONE

View File

@ -0,0 +1,223 @@
import re
import sys
from moto.dynamodb2.exceptions import (
InvalidTokenException,
InvalidExpressionAttributeNameKey,
)
class Token(object):
_TOKEN_INSTANCE = None
MINUS_SIGN = "-"
PLUS_SIGN = "+"
SPACE_SIGN = " "
EQUAL_SIGN = "="
OPEN_ROUND_BRACKET = "("
CLOSE_ROUND_BRACKET = ")"
COMMA = ","
SPACE = " "
DOT = "."
OPEN_SQUARE_BRACKET = "["
CLOSE_SQUARE_BRACKET = "]"
SPECIAL_CHARACTERS = [
MINUS_SIGN,
PLUS_SIGN,
SPACE_SIGN,
EQUAL_SIGN,
OPEN_ROUND_BRACKET,
CLOSE_ROUND_BRACKET,
COMMA,
SPACE,
DOT,
OPEN_SQUARE_BRACKET,
CLOSE_SQUARE_BRACKET,
]
# Attribute: an identifier that is an attribute
ATTRIBUTE = 0
# Place holder for attribute name
ATTRIBUTE_NAME = 1
# Placeholder for attribute value starts with :
ATTRIBUTE_VALUE = 2
# WhiteSpace shall be grouped together
WHITESPACE = 3
# Placeholder for a number
NUMBER = 4
PLACEHOLDER_NAMES = {
ATTRIBUTE: "Attribute",
ATTRIBUTE_NAME: "AttributeName",
ATTRIBUTE_VALUE: "AttributeValue",
WHITESPACE: "Whitespace",
NUMBER: "Number",
}
def __init__(self, token_type, value):
assert (
token_type in self.SPECIAL_CHARACTERS
or token_type in self.PLACEHOLDER_NAMES
)
self.type = token_type
self.value = value
def __repr__(self):
if isinstance(self.type, int):
return 'Token("{tt}", "{tv}")'.format(
tt=self.PLACEHOLDER_NAMES[self.type], tv=self.value
)
else:
return 'Token("{tt}", "{tv}")'.format(tt=self.type, tv=self.value)
def __eq__(self, other):
return self.type == other.type and self.value == other.value
class ExpressionTokenizer(object):
"""
Takes a string and returns a list of tokens. While attribute names in DynamoDB must be between 1 and 255 characters
long there are no other restrictions for attribute names. For expressions however there are additional rules. If an
attribute name does not adhere then it must be passed via an ExpressionAttributeName. This tokenizer is aware of the
rules of Expression attributes.
We consider a Token as a tuple which has the tokenType
From https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html
1) If an attribute name begins with a number or contains a space, a special character, or a reserved word, you
must use an expression attribute name to replace that attribute's name in the expression.
=> So spaces,+,- or other special characters do identify tokens in update expressions
2) When using a dot (.) in an attribute name you must use expression-attribute-names. A dot in an expression
will be interpreted as a separator in a document path
3) For a nested structure if you want to use expression_attribute_names you must specify one per part of the
path. Since for members of expression_attribute_names the . is part of the name
"""
@classmethod
def is_simple_token_character(cls, character):
return character.isalnum() or character in ("_", ":", "#")
@classmethod
def is_possible_token_boundary(cls, character):
return (
character in Token.SPECIAL_CHARACTERS
or not cls.is_simple_token_character(character)
)
@classmethod
def is_expression_attribute(cls, input_string):
return re.compile("^[a-zA-Z0-9][a-zA-Z0-9_]*$").match(input_string) is not None
@classmethod
def is_expression_attribute_name(cls, input_string):
"""
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html
An expression attribute name must begin with a pound sign (#), and be followed by one or more alphanumeric
characters.
"""
return input_string.startswith("#") and cls.is_expression_attribute(
input_string[1:]
)
@classmethod
def is_expression_attribute_value(cls, input_string):
return re.compile("^:[a-zA-Z0-9_]*$").match(input_string) is not None
def raise_unexpected_token(self):
"""If during parsing an unexpected token is encountered"""
if len(self.token_list) == 0:
near = ""
else:
if len(self.token_list) == 1:
near = self.token_list[-1].value
else:
if self.token_list[-1].type == Token.WHITESPACE:
# Last token was whitespace take 2nd last token value as well to help User orientate
near = self.token_list[-2].value + self.token_list[-1].value
else:
near = self.token_list[-1].value
problematic_token = self.staged_characters[0]
raise InvalidTokenException(problematic_token, near + self.staged_characters)
def __init__(self, input_expression_str):
self.input_expression_str = input_expression_str
self.token_list = []
self.staged_characters = ""
@classmethod
def is_py2(cls):
return sys.version_info[0] == 2
@classmethod
def make_list(cls, input_expression_str):
if cls.is_py2():
pass
else:
assert isinstance(input_expression_str, str)
return ExpressionTokenizer(input_expression_str)._make_list()
def add_token(self, token_type, token_value):
self.token_list.append(Token(token_type, token_value))
def add_token_from_stage(self, token_type):
self.add_token(token_type, self.staged_characters)
self.staged_characters = ""
@classmethod
def is_numeric(cls, input_str):
return re.compile("[0-9]+").match(input_str) is not None
def process_staged_characters(self):
if len(self.staged_characters) == 0:
return
if self.staged_characters.startswith("#"):
if self.is_expression_attribute_name(self.staged_characters):
self.add_token_from_stage(Token.ATTRIBUTE_NAME)
else:
raise InvalidExpressionAttributeNameKey(self.staged_characters)
elif self.is_numeric(self.staged_characters):
self.add_token_from_stage(Token.NUMBER)
elif self.is_expression_attribute(self.staged_characters):
self.add_token_from_stage(Token.ATTRIBUTE)
elif self.is_expression_attribute_value(self.staged_characters):
self.add_token_from_stage(Token.ATTRIBUTE_VALUE)
else:
self.raise_unexpected_token()
def _make_list(self):
"""
Just go through characters if a character is not a token boundary stage it for adding it as a grouped token
later if it is a tokenboundary process staged characters and then process the token boundary as well.
"""
for character in self.input_expression_str:
if not self.is_possible_token_boundary(character):
self.staged_characters += character
else:
self.process_staged_characters()
if character == Token.SPACE:
if (
len(self.token_list) > 0
and self.token_list[-1].type == Token.WHITESPACE
):
self.token_list[-1].value = (
self.token_list[-1].value + character
)
else:
self.add_token(Token.WHITESPACE, character)
elif character in Token.SPECIAL_CHARACTERS:
self.add_token(character, character)
elif not self.is_simple_token_character(character):
self.staged_characters += character
self.raise_unexpected_token()
else:
raise NotImplementedError(
"Encountered character which was not implemented : " + character
)
self.process_staged_characters()
return self.token_list

View File

@ -0,0 +1,394 @@
"""
See docstring class Validator below for more details on validation
"""
from abc import abstractmethod
from copy import deepcopy
from moto.dynamodb2.exceptions import (
AttributeIsReservedKeyword,
ExpressionAttributeValueNotDefined,
AttributeDoesNotExist,
ExpressionAttributeNameNotDefined,
IncorrectOperandType,
InvalidUpdateExpressionInvalidDocumentPath,
ProvidedKeyDoesNotExist,
EmptyKeyAttributeException,
)
from moto.dynamodb2.models import DynamoType
from moto.dynamodb2.parsing.ast_nodes import (
ExpressionAttribute,
UpdateExpressionPath,
UpdateExpressionSetAction,
UpdateExpressionAddAction,
UpdateExpressionDeleteAction,
UpdateExpressionRemoveAction,
DDBTypedValue,
ExpressionAttributeValue,
ExpressionAttributeName,
DepthFirstTraverser,
NoneExistingPath,
UpdateExpressionFunction,
ExpressionPathDescender,
UpdateExpressionValue,
ExpressionValueOperator,
ExpressionSelector,
)
from moto.dynamodb2.parsing.reserved_keywords import ReservedKeywords
class ExpressionAttributeValueProcessor(DepthFirstTraverser):
def __init__(self, expression_attribute_values):
self.expression_attribute_values = expression_attribute_values
def _processing_map(self):
return {
ExpressionAttributeValue: self.replace_expression_attribute_value_with_value
}
def replace_expression_attribute_value_with_value(self, node):
"""A node representing an Expression Attribute Value. Resolve and replace value"""
assert isinstance(node, ExpressionAttributeValue)
attribute_value_name = node.get_value_name()
try:
target = self.expression_attribute_values[attribute_value_name]
except KeyError:
raise ExpressionAttributeValueNotDefined(
attribute_value=attribute_value_name
)
return DDBTypedValue(DynamoType(target))
class ExpressionPathResolver(object):
def __init__(self, expression_attribute_names):
self.expression_attribute_names = expression_attribute_names
@classmethod
def raise_exception_if_keyword(cls, attribute):
if attribute.upper() in ReservedKeywords.get_reserved_keywords():
raise AttributeIsReservedKeyword(attribute)
def resolve_expression_path(self, item, update_expression_path):
assert isinstance(update_expression_path, UpdateExpressionPath)
return self.resolve_expression_path_nodes(item, update_expression_path.children)
def resolve_expression_path_nodes(self, item, update_expression_path_nodes):
target = item.attrs
for child in update_expression_path_nodes:
# First replace placeholder with attribute_name
attr_name = None
if isinstance(child, ExpressionAttributeName):
attr_placeholder = child.get_attribute_name_placeholder()
try:
attr_name = self.expression_attribute_names[attr_placeholder]
except KeyError:
raise ExpressionAttributeNameNotDefined(attr_placeholder)
elif isinstance(child, ExpressionAttribute):
attr_name = child.get_attribute_name()
self.raise_exception_if_keyword(attr_name)
if attr_name is not None:
# Resolv attribute_name
try:
target = target[attr_name]
except (KeyError, TypeError):
if child == update_expression_path_nodes[-1]:
return NoneExistingPath(creatable=True)
return NoneExistingPath()
else:
if isinstance(child, ExpressionPathDescender):
continue
elif isinstance(child, ExpressionSelector):
index = child.get_index()
if target.is_list():
try:
target = target[index]
except IndexError:
# When a list goes out of bounds when assigning that is no problem when at the assignment
# side. It will just append to the list.
if child == update_expression_path_nodes[-1]:
return NoneExistingPath(creatable=True)
return NoneExistingPath()
else:
raise InvalidUpdateExpressionInvalidDocumentPath
else:
raise NotImplementedError(
"Path resolution for {t}".format(t=type(child))
)
if not isinstance(target, DynamoType):
print(target)
return DDBTypedValue(target)
def resolve_expression_path_nodes_to_dynamo_type(
self, item, update_expression_path_nodes
):
node = self.resolve_expression_path_nodes(item, update_expression_path_nodes)
if isinstance(node, NoneExistingPath):
raise ProvidedKeyDoesNotExist()
assert isinstance(node, DDBTypedValue)
return node.get_value()
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
def _processing_map(self):
return {
UpdateExpressionSetAction: self.disable_resolving,
UpdateExpressionPath: self.process_expression_path_node,
}
def __init__(self, expression_attribute_names, item):
self.expression_attribute_names = expression_attribute_names
self.item = item
self.resolving = False
def pre_processing_of_child(self, parent_node, child_id):
"""
We have to enable resolving if we are processing a child of UpdateExpressionSetAction that is not first.
Because first argument is path to be set, 2nd argument would be the value.
"""
if isinstance(
parent_node,
(
UpdateExpressionSetAction,
UpdateExpressionRemoveAction,
UpdateExpressionDeleteAction,
UpdateExpressionAddAction,
),
):
if child_id == 0:
self.resolving = False
else:
self.resolving = True
def disable_resolving(self, node=None):
self.resolving = False
return node
def process_expression_path_node(self, node):
"""Resolve ExpressionAttribute if not part of a path and resolving is enabled."""
if self.resolving:
return self.resolve_expression_path(node)
else:
# Still resolve but return original note to make sure path is correct Just make sure nodes are creatable.
result_node = self.resolve_expression_path(node)
if (
isinstance(result_node, NoneExistingPath)
and not result_node.is_creatable()
):
raise InvalidUpdateExpressionInvalidDocumentPath()
return node
def resolve_expression_path(self, node):
return ExpressionPathResolver(
self.expression_attribute_names
).resolve_expression_path(self.item, node)
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
"""
At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET
expression as per the official AWS docs:
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/
Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET
"""
def _processing_map(self):
return {UpdateExpressionFunction: self.process_function}
def process_function(self, node):
assert isinstance(node, UpdateExpressionFunction)
function_name = node.get_function_name()
first_arg = node.get_nth_argument(1)
second_arg = node.get_nth_argument(2)
if function_name == "if_not_exists":
if isinstance(first_arg, NoneExistingPath):
result = second_arg
else:
result = first_arg
assert isinstance(result, (DDBTypedValue, NoneExistingPath))
return result
elif function_name == "list_append":
first_arg = deepcopy(
self.get_list_from_ddb_typed_value(first_arg, function_name)
)
second_arg = self.get_list_from_ddb_typed_value(second_arg, function_name)
for list_element in second_arg.value:
first_arg.value.append(list_element)
return DDBTypedValue(first_arg)
else:
raise NotImplementedError(
"Unsupported function for moto {name}".format(name=function_name)
)
@classmethod
def get_list_from_ddb_typed_value(cls, node, function_name):
assert isinstance(node, DDBTypedValue)
dynamo_value = node.get_value()
assert isinstance(dynamo_value, DynamoType)
if not dynamo_value.is_list():
raise IncorrectOperandType(function_name, dynamo_value.type)
return dynamo_value
class NoneExistingPathChecker(DepthFirstTraverser):
"""
Pass through the AST and make sure there are no none-existing paths.
"""
def _processing_map(self):
return {NoneExistingPath: self.raise_none_existing_path}
def raise_none_existing_path(self, node):
raise AttributeDoesNotExist
class ExecuteOperations(DepthFirstTraverser):
def _processing_map(self):
return {UpdateExpressionValue: self.process_update_expression_value}
def process_update_expression_value(self, node):
"""
If an UpdateExpressionValue only has a single child the node will be replaced with the childe.
Otherwise it has 3 children and the middle one is an ExpressionValueOperator which details how to combine them
Args:
node(Node):
Returns:
Node: The resulting node of the operation if present or the child.
"""
assert isinstance(node, UpdateExpressionValue)
if len(node.children) == 1:
return node.children[0]
elif len(node.children) == 3:
operator_node = node.children[1]
assert isinstance(operator_node, ExpressionValueOperator)
operator = operator_node.get_operator()
left_operand = self.get_dynamo_value_from_ddb_typed_value(node.children[0])
right_operand = self.get_dynamo_value_from_ddb_typed_value(node.children[2])
if operator == "+":
return self.get_sum(left_operand, right_operand)
elif operator == "-":
return self.get_subtraction(left_operand, right_operand)
else:
raise NotImplementedError(
"Moto does not support operator {operator}".format(
operator=operator
)
)
else:
raise NotImplementedError(
"UpdateExpressionValue only has implementations for 1 or 3 children."
)
@classmethod
def get_dynamo_value_from_ddb_typed_value(cls, node):
assert isinstance(node, DDBTypedValue)
dynamo_value = node.get_value()
assert isinstance(dynamo_value, DynamoType)
return dynamo_value
@classmethod
def get_sum(cls, left_operand, right_operand):
"""
Args:
left_operand(DynamoType):
right_operand(DynamoType):
Returns:
DDBTypedValue:
"""
try:
return DDBTypedValue(left_operand + right_operand)
except TypeError:
raise IncorrectOperandType("+", left_operand.type)
@classmethod
def get_subtraction(cls, left_operand, right_operand):
"""
Args:
left_operand(DynamoType):
right_operand(DynamoType):
Returns:
DDBTypedValue:
"""
try:
return DDBTypedValue(left_operand - right_operand)
except TypeError:
raise IncorrectOperandType("-", left_operand.type)
class EmptyStringKeyValueValidator(DepthFirstTraverser):
def __init__(self, key_attributes):
self.key_attributes = key_attributes
def _processing_map(self):
return {UpdateExpressionSetAction: self.check_for_empty_string_key_value}
def check_for_empty_string_key_value(self, node):
"""A node representing a SET action. Check that keys are not being assigned empty strings"""
assert isinstance(node, UpdateExpressionSetAction)
assert len(node.children) == 2
key = node.children[0].children[0].children[0]
val_node = node.children[1].children[0]
if val_node.type in ["S", "B"] and key in self.key_attributes:
raise EmptyKeyAttributeException
return node
class Validator(object):
"""
A validator is used to validate expressions which are passed in as an AST.
"""
def __init__(
self,
expression,
expression_attribute_names,
expression_attribute_values,
item,
table,
):
"""
Besides validation the Validator should also replace referenced parts of an item which is cheapest upon
validation.
Args:
expression(Node): The root node of the AST representing the expression to be validated
expression_attribute_names(ExpressionAttributeNames):
expression_attribute_values(ExpressionAttributeValues):
item(Item): The item which will be updated (pointed to by Key of update_item)
"""
self.expression_attribute_names = expression_attribute_names
self.expression_attribute_values = expression_attribute_values
self.item = item
self.table = table
self.processors = self.get_ast_processors()
self.node_to_validate = deepcopy(expression)
@abstractmethod
def get_ast_processors(self):
"""Get the different processors that go through the AST tree and processes the nodes."""
def validate(self):
n = self.node_to_validate
for processor in self.processors:
n = processor.traverse(n)
return n
class UpdateExpressionValidator(Validator):
def get_ast_processors(self):
"""Get the different processors that go through the AST tree and processes the nodes."""
processors = [
ExpressionAttributeValueProcessor(self.expression_attribute_values),
ExpressionAttributeResolvingProcessor(
self.expression_attribute_names, self.item
),
UpdateExpressionFunctionEvaluator(),
NoneExistingPathChecker(),
ExecuteOperations(),
EmptyStringKeyValueValidator(self.table.key_attributes),
]
return processors

View File

@ -1,24 +1,38 @@
from __future__ import unicode_literals
import itertools
import copy
import json
import six
import re
import itertools
import six
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores, amzn_request_id
from .exceptions import InvalidIndexNameError, InvalidUpdateExpression, ItemSizeTooLarge
from .models import dynamodb_backends, dynamo_json_dump
from .exceptions import (
InvalidIndexNameError,
ItemSizeTooLarge,
MockValidationException,
TransactionCanceledException,
)
from moto.dynamodb2.models import dynamodb_backends, dynamo_json_dump
def has_empty_keys_or_values(_dict):
if _dict == "":
return True
if not isinstance(_dict, dict):
return False
return any(
key == "" or value == "" or has_empty_keys_or_values(value)
for key, value in _dict.items()
)
TRANSACTION_MAX_ITEMS = 25
def put_has_empty_keys(field_updates, table):
if table:
key_names = table.key_attributes
# string/binary fields with empty string as value
empty_str_fields = [
key
for (key, val) in field_updates.items()
if next(iter(val.keys())) in ["S", "B"] and next(iter(val.values())) == ""
]
return any([keyname in empty_str_fields for keyname in key_names])
return False
def get_empty_str_error():
@ -86,19 +100,14 @@ class DynamoHandler(BaseResponse):
def list_tables(self):
body = self.body
limit = body.get("Limit", 100)
if body.get("ExclusiveStartTableName"):
last = body.get("ExclusiveStartTableName")
start = list(self.dynamodb_backend.tables.keys()).index(last) + 1
else:
start = 0
all_tables = list(self.dynamodb_backend.tables.keys())
if limit:
tables = all_tables[start : start + limit]
else:
tables = all_tables[start:]
exclusive_start_table_name = body.get("ExclusiveStartTableName")
tables, last_eval = self.dynamodb_backend.list_tables(
limit, exclusive_start_table_name
)
response = {"TableNames": tables}
if limit and len(all_tables) > start + limit:
response["LastEvaluatedTableName"] = tables[-1]
if last_eval:
response["LastEvaluatedTableName"] = last_eval
return dynamo_json_dump(response)
@ -218,33 +227,29 @@ class DynamoHandler(BaseResponse):
def update_table(self):
name = self.body["TableName"]
table = self.dynamodb_backend.get_table(name)
if "GlobalSecondaryIndexUpdates" in self.body:
table = self.dynamodb_backend.update_table_global_indexes(
name, self.body["GlobalSecondaryIndexUpdates"]
global_index = self.body.get("GlobalSecondaryIndexUpdates", None)
throughput = self.body.get("ProvisionedThroughput", None)
stream_spec = self.body.get("StreamSpecification", None)
try:
table = self.dynamodb_backend.update_table(
name=name,
global_index=global_index,
throughput=throughput,
stream_spec=stream_spec,
)
if "ProvisionedThroughput" in self.body:
throughput = self.body["ProvisionedThroughput"]
table = self.dynamodb_backend.update_table_throughput(name, throughput)
if "StreamSpecification" in self.body:
try:
table = self.dynamodb_backend.update_table_streams(
name, self.body["StreamSpecification"]
)
except ValueError:
er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException"
return self.error(er, "Cannot enable stream")
return dynamo_json_dump(table.describe())
return dynamo_json_dump(table.describe())
except ValueError:
er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException"
return self.error(er, "Cannot enable stream")
def describe_table(self):
name = self.body["TableName"]
try:
table = self.dynamodb_backend.tables[name]
table = self.dynamodb_backend.describe_table(name)
return dynamo_json_dump(table)
except KeyError:
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er, "Requested resource not found")
return dynamo_json_dump(table.describe(base_key="Table"))
def put_item(self):
name = self.body["TableName"]
@ -255,7 +260,7 @@ class DynamoHandler(BaseResponse):
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(er, "Return values set to invalid value")
if has_empty_keys_or_values(item):
if put_has_empty_keys(item, self.dynamodb_backend.get_table(name)):
return get_empty_str_error()
overwrite = "Expected" not in self.body
@ -292,12 +297,13 @@ class DynamoHandler(BaseResponse):
)
except ItemSizeTooLarge:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(er, ItemSizeTooLarge.message)
except ValueError:
return self.error(er, ItemSizeTooLarge.item_size_too_large_msg)
except KeyError as ke:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(er, ke.args[0])
except ValueError as ve:
er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException"
return self.error(
er, "A condition specified in the operation could not be evaluated."
)
return self.error(er, str(ve))
if result:
item_dict = result.to_json()
@ -368,6 +374,26 @@ class DynamoHandler(BaseResponse):
results = {"ConsumedCapacity": [], "Responses": {}, "UnprocessedKeys": {}}
# Validation: Can only request up to 100 items at the same time
# Scenario 1: We're requesting more than a 100 keys from a single table
for table_name, table_request in table_batches.items():
if len(table_request["Keys"]) > 100:
return self.error(
"com.amazonaws.dynamodb.v20111205#ValidationException",
"1 validation error detected: Value at 'requestItems."
+ table_name
+ ".member.keys' failed to satisfy constraint: Member must have length less than or equal to 100",
)
# Scenario 2: We're requesting more than a 100 keys across all tables
nr_of_keys_across_all_tables = sum(
[len(req["Keys"]) for _, req in table_batches.items()]
)
if nr_of_keys_across_all_tables > 100:
return self.error(
"com.amazonaws.dynamodb.v20111205#ValidationException",
"Too many items requested for the BatchGetItem call",
)
for table_name, table_request in table_batches.items():
keys = table_request["Keys"]
if self._contains_duplicates(keys):
@ -408,7 +434,6 @@ class DynamoHandler(BaseResponse):
def query(self):
name = self.body["TableName"]
# {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}}
key_condition_expression = self.body.get("KeyConditionExpression")
projection_expression = self.body.get("ProjectionExpression")
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
@ -436,7 +461,7 @@ class DynamoHandler(BaseResponse):
index_name = self.body.get("IndexName")
if index_name:
all_indexes = (table.global_indexes or []) + (table.indexes or [])
indexes_by_name = dict((i["IndexName"], i) for i in all_indexes)
indexes_by_name = dict((i.name, i) for i in all_indexes)
if index_name not in indexes_by_name:
er = "com.amazonaws.dynamodb.v20120810#ResourceNotFoundException"
return self.error(
@ -446,7 +471,7 @@ class DynamoHandler(BaseResponse):
),
)
index = indexes_by_name[index_name]["KeySchema"]
index = indexes_by_name[index_name].schema
else:
index = table.schema
@ -455,8 +480,10 @@ class DynamoHandler(BaseResponse):
for k, v in six.iteritems(self.body.get("ExpressionAttributeNames", {}))
)
if " AND " in key_condition_expression:
expressions = key_condition_expression.split(" AND ", 1)
if " and " in key_condition_expression.lower():
expressions = re.split(
" AND ", key_condition_expression, maxsplit=1, flags=re.IGNORECASE
)
index_hash_key = [key for key in index if key["KeyType"] == "HASH"][0]
hash_key_var = reverse_attribute_lookup.get(
@ -710,7 +737,8 @@ class DynamoHandler(BaseResponse):
attribute_updates = self.body.get("AttributeUpdates")
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
existing_item = self.dynamodb_backend.get_item(name, key)
# We need to copy the item in order to avoid it being modified by the update_item operation
existing_item = copy.deepcopy(self.dynamodb_backend.get_item(name, key))
if existing_item:
existing_attributes = existing_item.to_json()["Attributes"]
else:
@ -726,9 +754,6 @@ class DynamoHandler(BaseResponse):
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(er, "Return values set to invalid value")
if has_empty_keys_or_values(expression_attribute_values):
return get_empty_str_error()
if "Expected" in self.body:
expected = self.body["Expected"]
else:
@ -740,31 +765,20 @@ class DynamoHandler(BaseResponse):
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
if update_expression:
update_expression = re.sub(r"\s*([=\+-])\s*", "\\1", update_expression)
try:
item = self.dynamodb_backend.update_item(
name,
key,
update_expression,
attribute_updates,
expression_attribute_names,
expression_attribute_values,
expected,
condition_expression,
update_expression=update_expression,
attribute_updates=attribute_updates,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
expected=expected,
condition_expression=condition_expression,
)
except InvalidUpdateExpression:
except MockValidationException as mve:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(
er,
"The document path provided in the update expression is invalid for update",
)
except ItemSizeTooLarge:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(er, ItemSizeTooLarge.message)
return self.error(er, mve.exception_msg)
except ValueError:
er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException"
return self.error(
@ -796,14 +810,36 @@ class DynamoHandler(BaseResponse):
k: v for k, v in existing_attributes.items() if k in changed_attributes
}
elif return_values == "UPDATED_NEW":
item_dict["Attributes"] = {
k: v
for k, v in item_dict["Attributes"].items()
if k in changed_attributes
}
item_dict["Attributes"] = self._build_updated_new_attributes(
existing_attributes, item_dict["Attributes"]
)
return dynamo_json_dump(item_dict)
def _build_updated_new_attributes(self, original, changed):
if type(changed) != type(original):
return changed
else:
if type(changed) is dict:
return {
key: self._build_updated_new_attributes(
original.get(key, None), changed[key]
)
for key in changed.keys()
if key not in original or changed[key] != original[key]
}
elif type(changed) in (set, list):
if len(changed) != len(original):
return changed
else:
return [
self._build_updated_new_attributes(
original[index], changed[index]
)
for index in range(len(changed))
]
else:
return changed
def describe_limits(self):
return json.dumps(
{
@ -818,13 +854,117 @@ class DynamoHandler(BaseResponse):
name = self.body["TableName"]
ttl_spec = self.body["TimeToLiveSpecification"]
self.dynamodb_backend.update_ttl(name, ttl_spec)
self.dynamodb_backend.update_time_to_live(name, ttl_spec)
return json.dumps({"TimeToLiveSpecification": ttl_spec})
def describe_time_to_live(self):
name = self.body["TableName"]
ttl_spec = self.dynamodb_backend.describe_ttl(name)
ttl_spec = self.dynamodb_backend.describe_time_to_live(name)
return json.dumps({"TimeToLiveDescription": ttl_spec})
def transact_get_items(self):
transact_items = self.body["TransactItems"]
responses = list()
if len(transact_items) > TRANSACTION_MAX_ITEMS:
msg = "1 validation error detected: Value '["
err_list = list()
request_id = 268435456
for _ in transact_items:
request_id += 1
hex_request_id = format(request_id, "x")
err_list.append(
"com.amazonaws.dynamodb.v20120810.TransactGetItem@%s"
% hex_request_id
)
msg += ", ".join(err_list)
msg += (
"'] at 'transactItems' failed to satisfy constraint: "
"Member must have length less than or equal to %s"
% TRANSACTION_MAX_ITEMS
)
return self.error("ValidationException", msg)
ret_consumed_capacity = self.body.get("ReturnConsumedCapacity", "NONE")
consumed_capacity = dict()
for transact_item in transact_items:
table_name = transact_item["Get"]["TableName"]
key = transact_item["Get"]["Key"]
try:
item = self.dynamodb_backend.get_item(table_name, key)
except ValueError:
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er, "Requested resource not found")
if not item:
responses.append({})
continue
item_describe = item.describe_attrs(False)
responses.append(item_describe)
table_capacity = consumed_capacity.get(table_name, {})
table_capacity["TableName"] = table_name
capacity_units = table_capacity.get("CapacityUnits", 0) + 2.0
table_capacity["CapacityUnits"] = capacity_units
read_capacity_units = table_capacity.get("ReadCapacityUnits", 0) + 2.0
table_capacity["ReadCapacityUnits"] = read_capacity_units
consumed_capacity[table_name] = table_capacity
if ret_consumed_capacity == "INDEXES":
table_capacity["Table"] = {
"CapacityUnits": capacity_units,
"ReadCapacityUnits": read_capacity_units,
}
result = dict()
result.update({"Responses": responses})
if ret_consumed_capacity != "NONE":
result.update({"ConsumedCapacity": [v for v in consumed_capacity.values()]})
return dynamo_json_dump(result)
def transact_write_items(self):
transact_items = self.body["TransactItems"]
try:
self.dynamodb_backend.transact_write_items(transact_items)
except TransactionCanceledException as e:
er = "com.amazonaws.dynamodb.v20111205#TransactionCanceledException"
return self.error(er, str(e))
response = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}}
return dynamo_json_dump(response)
def describe_continuous_backups(self):
name = self.body["TableName"]
if self.dynamodb_backend.get_table(name) is None:
return self.error(
"com.amazonaws.dynamodb.v20111205#TableNotFoundException",
"Table not found: {}".format(name),
)
response = self.dynamodb_backend.describe_continuous_backups(name)
return json.dumps({"ContinuousBackupsDescription": response})
def update_continuous_backups(self):
name = self.body["TableName"]
point_in_time_spec = self.body["PointInTimeRecoverySpecification"]
if self.dynamodb_backend.get_table(name) is None:
return self.error(
"com.amazonaws.dynamodb.v20111205#TableNotFoundException",
"Table not found: {}".format(name),
)
response = self.dynamodb_backend.update_continuous_backups(
name, point_in_time_spec
)
return json.dumps({"ContinuousBackupsDescription": response})

View File

@ -7,7 +7,7 @@ import base64
from boto3 import Session
from moto.core import BaseBackend, BaseModel
from moto.dynamodb2.models import dynamodb_backends
from moto.dynamodb2.models import dynamodb_backends, DynamoJsonEncoder
class ShardIterator(BaseModel):
@ -137,7 +137,7 @@ class DynamoDBStreamsBackend(BaseBackend):
def get_records(self, iterator_arn, limit):
shard_iterator = self.shard_iterators[iterator_arn]
return json.dumps(shard_iterator.get(limit))
return json.dumps(shard_iterator.get(limit), cls=DynamoJsonEncoder)
dynamodbstreams_backends = {}

View File

@ -71,6 +71,24 @@ class InvalidSubnetIdError(EC2ClientError):
)
class InvalidFlowLogIdError(EC2ClientError):
def __init__(self, count, flow_log_ids):
super(InvalidFlowLogIdError, self).__init__(
"InvalidFlowLogId.NotFound",
"These flow log ids in the input list are not found: [TotalCount: {0}] {1}".format(
count, flow_log_ids
),
)
class FlowLogAlreadyExists(EC2ClientError):
def __init__(self):
super(FlowLogAlreadyExists, self).__init__(
"FlowLogAlreadyExists",
"Error. There is an existing Flow Log with the same configuration and log destination.",
)
class InvalidNetworkAclIdError(EC2ClientError):
def __init__(self, network_acl_id):
super(InvalidNetworkAclIdError, self).__init__(
@ -231,6 +249,24 @@ class InvalidVolumeAttachmentError(EC2ClientError):
)
class InvalidVolumeDetachmentError(EC2ClientError):
def __init__(self, volume_id, instance_id, device):
super(InvalidVolumeDetachmentError, self).__init__(
"InvalidAttachment.NotFound",
"The volume {0} is not attached to instance {1} as device {2}".format(
volume_id, instance_id, device
),
)
class VolumeInUseError(EC2ClientError):
def __init__(self, volume_id, instance_id):
super(VolumeInUseError, self).__init__(
"VolumeInUse",
"Volume {0} is currently attached to {1}".format(volume_id, instance_id),
)
class InvalidDomainError(EC2ClientError):
def __init__(self, domain):
super(InvalidDomainError, self).__init__(
@ -245,6 +281,14 @@ class InvalidAddressError(EC2ClientError):
)
class LogDestinationNotFoundError(EC2ClientError):
def __init__(self, bucket_name):
super(LogDestinationNotFoundError, self).__init__(
"LogDestinationNotFoundException",
"LogDestination: '{0}' does not exist.".format(bucket_name),
)
class InvalidAllocationIdError(EC2ClientError):
def __init__(self, allocation_id):
super(InvalidAllocationIdError, self).__init__(
@ -291,6 +335,33 @@ class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
)
class InvalidDependantParameterError(EC2ClientError):
def __init__(self, dependant_parameter, parameter, parameter_value):
super(InvalidDependantParameterError, self).__init__(
"InvalidParameter",
"{0} can't be empty if {1} is {2}.".format(
dependant_parameter, parameter, parameter_value,
),
)
class InvalidDependantParameterTypeError(EC2ClientError):
def __init__(self, dependant_parameter, parameter_value, parameter):
super(InvalidDependantParameterTypeError, self).__init__(
"InvalidParameter",
"{0} type must be {1} if {2} is provided.".format(
dependant_parameter, parameter_value, parameter,
),
)
class InvalidAggregationIntervalParameterError(EC2ClientError):
def __init__(self, parameter):
super(InvalidAggregationIntervalParameterError, self).__init__(
"InvalidParameter", "Invalid {0}".format(parameter),
)
class InvalidParameterValueError(EC2ClientError):
def __init__(self, parameter_value):
super(InvalidParameterValueError, self).__init__(
@ -502,3 +573,29 @@ class InvalidLaunchTemplateNameError(EC2ClientError):
"InvalidLaunchTemplateName.AlreadyExistsException",
"Launch template name already in use.",
)
class InvalidParameterDependency(EC2ClientError):
def __init__(self, param, param_needed):
super(InvalidParameterDependency, self).__init__(
"InvalidParameterDependency",
"The parameter [{0}] requires the parameter {1} to be set.".format(
param, param_needed
),
)
class IncorrectStateIamProfileAssociationError(EC2ClientError):
def __init__(self, instance_id):
super(IncorrectStateIamProfileAssociationError, self).__init__(
"IncorrectState",
"There is an existing association for instance {0}".format(instance_id),
)
class InvalidAssociationIDIamProfileAssociationError(EC2ClientError):
def __init__(self, association_id):
super(InvalidAssociationIDIamProfileAssociationError, self).__init__(
"InvalidAssociationID.NotFound",
"An invalid association-id of '{0}' was given".format(association_id),
)

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@ -24,6 +24,7 @@ from .security_groups import SecurityGroups
from .spot_fleets import SpotFleets
from .spot_instances import SpotInstances
from .subnets import Subnets
from .flow_logs import FlowLogs
from .tags import TagResponse
from .virtual_private_gateways import VirtualPrivateGateways
from .vm_export import VMExport
@ -33,6 +34,7 @@ from .vpc_peering_connections import VPCPeeringConnections
from .vpn_connections import VPNConnections
from .windows import Windows
from .nat_gateways import NatGateways
from .iam_instance_profiles import IamInstanceProfiles
class EC2Response(
@ -60,6 +62,7 @@ class EC2Response(
SpotFleets,
SpotInstances,
Subnets,
FlowLogs,
TagResponse,
VirtualPrivateGateways,
VMExport,
@ -69,6 +72,7 @@ class EC2Response(
VPNConnections,
Windows,
NatGateways,
IamInstanceProfiles,
):
@property
def ec2_backend(self):

View File

@ -73,8 +73,12 @@ class AmisResponse(BaseResponse):
return MODIFY_IMAGE_ATTRIBUTE_RESPONSE
def register_image(self):
name = self.querystring.get("Name")[0]
description = self._get_param("Description", if_none="")
if self.is_not_dryrun("RegisterImage"):
raise NotImplementedError("AMIs.register_image is not yet implemented")
image = self.ec2_backend.register_image(name, description)
template = self.response_template(REGISTER_IMAGE_RESPONSE)
return template.render(image=image)
def reset_image_attribute(self):
if self.is_not_dryrun("ResetImageAttribute"):
@ -125,7 +129,7 @@ DESCRIBE_IMAGES_RESPONSE = """<DescribeImagesResponse xmlns="http://ec2.amazonaw
<snapshotId>{{ image.ebs_snapshot.id }}</snapshotId>
<volumeSize>15</volumeSize>
<deleteOnTermination>false</deleteOnTermination>
<volumeType>{{ image.root_device_type }}</volumeType>
<volumeType>standard</volumeType>
</ebs>
</item>
</blockDeviceMapping>
@ -190,3 +194,8 @@ MODIFY_IMAGE_ATTRIBUTE_RESPONSE = """
<return>true</return>
</ModifyImageAttributeResponse>
"""
REGISTER_IMAGE_RESPONSE = """<RegisterImageResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<imageId>{{ image.id }}</imageId>
</RegisterImageResponse>"""

View File

@ -22,6 +22,7 @@ DESCRIBE_REGIONS_RESPONSE = """<DescribeRegionsResponse xmlns="http://ec2.amazon
<item>
<regionName>{{ region.name }}</regionName>
<regionEndpoint>{{ region.endpoint }}</regionEndpoint>
<optInStatus>{{ region.opt_in_status }}</optInStatus>
</item>
{% endfor %}
</regionInfo>
@ -35,6 +36,7 @@ DESCRIBE_ZONES_RESPONSE = """<DescribeAvailabilityZonesResponse xmlns="http://ec
<zoneName>{{ zone.name }}</zoneName>
<zoneState>available</zoneState>
<regionName>{{ zone.region_name }}</regionName>
<zoneId>{{ zone.zone_id }}</zoneId>
<messageSet/>
</item>
{% endfor %}

View File

@ -19,10 +19,13 @@ class ElasticBlockStore(BaseResponse):
source_snapshot_id = self._get_param("SourceSnapshotId")
source_region = self._get_param("SourceRegion")
description = self._get_param("Description")
tags = self._parse_tag_specification("TagSpecification")
snapshot_tags = tags.get("snapshot", {})
if self.is_not_dryrun("CopySnapshot"):
snapshot = self.ec2_backend.copy_snapshot(
source_snapshot_id, source_region, description
)
snapshot.add_tags(snapshot_tags)
template = self.response_template(COPY_SNAPSHOT_RESPONSE)
return template.render(snapshot=snapshot)
@ -43,9 +46,12 @@ class ElasticBlockStore(BaseResponse):
snapshot_id = self._get_param("SnapshotId")
tags = self._parse_tag_specification("TagSpecification")
volume_tags = tags.get("volume", {})
encrypted = self._get_param("Encrypted", if_none=False)
encrypted = self._get_bool_param("Encrypted", if_none=False)
kms_key_id = self._get_param("KmsKeyId")
if self.is_not_dryrun("CreateVolume"):
volume = self.ec2_backend.create_volume(size, zone, snapshot_id, encrypted)
volume = self.ec2_backend.create_volume(
size, zone, snapshot_id, encrypted, kms_key_id
)
volume.add_tags(volume_tags)
template = self.response_template(CREATE_VOLUME_RESPONSE)
return template.render(volume=volume)
@ -116,22 +122,23 @@ class ElasticBlockStore(BaseResponse):
def describe_snapshot_attribute(self):
snapshot_id = self._get_param("SnapshotId")
groups = self.ec2_backend.get_create_volume_permission_groups(snapshot_id)
user_ids = self.ec2_backend.get_create_volume_permission_userids(snapshot_id)
template = self.response_template(DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE)
return template.render(snapshot_id=snapshot_id, groups=groups)
return template.render(snapshot_id=snapshot_id, groups=groups, userIds=user_ids)
def modify_snapshot_attribute(self):
snapshot_id = self._get_param("SnapshotId")
operation_type = self._get_param("OperationType")
group = self._get_param("UserGroup.1")
user_id = self._get_param("UserId.1")
groups = self._get_multi_param("UserGroup")
user_ids = self._get_multi_param("UserId")
if self.is_not_dryrun("ModifySnapshotAttribute"):
if operation_type == "add":
self.ec2_backend.add_create_volume_permission(
snapshot_id, user_id=user_id, group=group
snapshot_id, user_ids=user_ids, groups=groups
)
elif operation_type == "remove":
self.ec2_backend.remove_create_volume_permission(
snapshot_id, user_id=user_id, group=group
snapshot_id, user_ids=user_ids, groups=groups
)
return MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE
@ -157,7 +164,10 @@ CREATE_VOLUME_RESPONSE = """<CreateVolumeResponse xmlns="http://ec2.amazonaws.co
{% else %}
<snapshotId/>
{% endif %}
<encrypted>{{ volume.encrypted }}</encrypted>
<encrypted>{{ 'true' if volume.encrypted else 'false' }}</encrypted>
{% if volume.encrypted %}
<kmsKeyId>{{ volume.kms_key_id }}</kmsKeyId>
{% endif %}
<availabilityZone>{{ volume.zone.name }}</availabilityZone>
<status>creating</status>
<createTime>{{ volume.create_time}}</createTime>
@ -188,7 +198,10 @@ DESCRIBE_VOLUMES_RESPONSE = """<DescribeVolumesResponse xmlns="http://ec2.amazon
{% else %}
<snapshotId/>
{% endif %}
<encrypted>{{ volume.encrypted }}</encrypted>
<encrypted>{{ 'true' if volume.encrypted else 'false' }}</encrypted>
{% if volume.encrypted %}
<kmsKeyId>{{ volume.kms_key_id }}</kmsKeyId>
{% endif %}
<availabilityZone>{{ volume.zone.name }}</availabilityZone>
<status>{{ volume.status }}</status>
<createTime>{{ volume.create_time}}</createTime>
@ -271,6 +284,16 @@ CREATE_SNAPSHOT_RESPONSE = """<CreateSnapshotResponse xmlns="http://ec2.amazonaw
COPY_SNAPSHOT_RESPONSE = """<CopySnapshotResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<snapshotId>{{ snapshot.id }}</snapshotId>
<tagSet>
{% for tag in snapshot.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</CopySnapshotResponse>"""
DESCRIBE_SNAPSHOTS_RESPONSE = """<DescribeSnapshotsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@ -311,18 +334,18 @@ DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE = """
<DescribeSnapshotAttributeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>a9540c9f-161a-45d8-9cc1-1182b89ad69f</requestId>
<snapshotId>snap-a0332ee0</snapshotId>
{% if not groups %}
<createVolumePermission/>
{% endif %}
{% if groups %}
<createVolumePermission>
{% for group in groups %}
<item>
<group>{{ group }}</group>
</item>
{% endfor %}
</createVolumePermission>
{% endif %}
<createVolumePermission>
{% for group in groups %}
<item>
<group>{{ group }}</group>
</item>
{% endfor %}
{% for userId in userIds %}
<item>
<userId>{{ userId }}</userId>
</item>
{% endfor %}
</createVolumePermission>
</DescribeSnapshotAttributeResponse>
"""

View File

@ -0,0 +1,122 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.models import validate_resource_ids
from moto.ec2.utils import filters_from_querystring
class FlowLogs(BaseResponse):
def create_flow_logs(self):
resource_type = self._get_param("ResourceType")
resource_ids = self._get_multi_param("ResourceId")
traffic_type = self._get_param("TrafficType")
deliver_logs_permission_arn = self._get_param("DeliverLogsPermissionArn")
log_destination_type = self._get_param("LogDestinationType")
log_destination = self._get_param("LogDestination")
log_group_name = self._get_param("LogGroupName")
log_format = self._get_param("LogFormat")
max_aggregation_interval = self._get_param("MaxAggregationInterval")
validate_resource_ids(resource_ids)
tags = self._parse_tag_specification("TagSpecification")
tags = tags.get("vpc-flow-log", {})
if self.is_not_dryrun("CreateFlowLogs"):
flow_logs, errors = self.ec2_backend.create_flow_logs(
resource_type=resource_type,
resource_ids=resource_ids,
traffic_type=traffic_type,
deliver_logs_permission_arn=deliver_logs_permission_arn,
log_destination_type=log_destination_type,
log_destination=log_destination,
log_group_name=log_group_name,
log_format=log_format,
max_aggregation_interval=max_aggregation_interval,
)
for fl in flow_logs:
fl.add_tags(tags)
template = self.response_template(CREATE_FLOW_LOGS_RESPONSE)
return template.render(flow_logs=flow_logs, errors=errors)
def describe_flow_logs(self):
flow_log_ids = self._get_multi_param("FlowLogId")
filters = filters_from_querystring(self.querystring)
flow_logs = self.ec2_backend.describe_flow_logs(flow_log_ids, filters)
if self.is_not_dryrun("DescribeFlowLogs"):
template = self.response_template(DESCRIBE_FLOW_LOGS_RESPONSE)
return template.render(flow_logs=flow_logs)
def delete_flow_logs(self):
flow_log_ids = self._get_multi_param("FlowLogId")
self.ec2_backend.delete_flow_logs(flow_log_ids)
if self.is_not_dryrun("DeleteFlowLogs"):
template = self.response_template(DELETE_FLOW_LOGS_RESPONSE)
return template.render()
CREATE_FLOW_LOGS_RESPONSE = """
<CreateFlowLogsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>2d96dae3-504b-4fc4-bf50-266EXAMPLE</requestId>
<unsuccessful>
{% for error in errors %}
<item>
<error>
<code>{{ error.1 }}</code>
<message>{{ error.2 }}</message>
</error>
<resourceId>{{ error.0 }}</resourceId>
</item>
{% endfor %}
</unsuccessful>
<flowLogIdSet>
{% for flow_log in flow_logs %}
<item>{{ flow_log.id }}</item>
{% endfor %}
</flowLogIdSet>
</CreateFlowLogsResponse>"""
DELETE_FLOW_LOGS_RESPONSE = """
<DeleteFlowLogsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>c5c4f51f-f4e9-42bc-8700-EXAMPLE</requestId>
<unsuccessful/>
</DeleteFlowLogsResponse>"""
DESCRIBE_FLOW_LOGS_RESPONSE = """
<DescribeFlowLogsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>3cb46f23-099e-4bf0-891c-EXAMPLE</requestId>
<flowLogSet>
{% for flow_log in flow_logs %}
<item>
{% if flow_log.log_destination is not none %}
<logDestination>{{ flow_log.log_destination }}</logDestination>
{% endif %}
<resourceId>{{ flow_log.resource_id }}</resourceId>
<logDestinationType>{{ flow_log.log_destination_type }}</logDestinationType>
<creationTime>{{ flow_log.created_at }}</creationTime>
<trafficType>{{ flow_log.traffic_type }}</trafficType>
<deliverLogsStatus>{{ flow_log.deliver_logs_status }}</deliverLogsStatus>
{% if flow_log.deliver_logs_error_message is not none %}
<deliverLogsErrorMessage>{{ flow_log.deliver_logs_error_message }}</deliverLogsErrorMessage>
{% endif %}
<logFormat>{{ flow_log.log_format }}</logFormat>
<flowLogStatus>ACTIVE</flowLogStatus>
<flowLogId>{{ flow_log.id }}</flowLogId>
<maxAggregationInterval>{{ flow_log.max_aggregation_interval }}</maxAggregationInterval>
{% if flow_log.deliver_logs_permission_arn is not none %}
<deliverLogsPermissionArn>{{ flow_log.deliver_logs_permission_arn }}</deliverLogsPermissionArn>
{% endif %}
{% if flow_log.log_group_name is not none %}
<logGroupName>{{ flow_log.log_group_name }}</logGroupName>
{% endif %}
{% if flow_log.get_tags() %}
<tagSet>
{% for tag in flow_log.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
{% endif %}
</item>
{% endfor %}
</flowLogSet>
</DescribeFlowLogsResponse>"""

View File

@ -0,0 +1,89 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
class IamInstanceProfiles(BaseResponse):
def associate_iam_instance_profile(self):
instance_id = self._get_param("InstanceId")
iam_instance_profile_name = self._get_param("IamInstanceProfile.Name")
iam_instance_profile_arn = self._get_param("IamInstanceProfile.Arn")
iam_association = self.ec2_backend.associate_iam_instance_profile(
instance_id, iam_instance_profile_name, iam_instance_profile_arn
)
template = self.response_template(IAM_INSTANCE_PROFILE_RESPONSE)
return template.render(iam_association=iam_association, state="associating")
def describe_iam_instance_profile_associations(self):
association_ids = self._get_multi_param("AssociationId")
filters = self._get_object_map("Filter")
max_items = self._get_param("MaxItems")
next_token = self._get_param("NextToken")
(
iam_associations,
next_token,
) = self.ec2_backend.describe_iam_instance_profile_associations(
association_ids, filters, max_items, next_token
)
template = self.response_template(DESCRIBE_IAM_INSTANCE_PROFILE_RESPONSE)
return template.render(iam_associations=iam_associations, next_token=next_token)
def disassociate_iam_instance_profile(self):
association_id = self._get_param("AssociationId")
iam_association = self.ec2_backend.disassociate_iam_instance_profile(
association_id
)
template = self.response_template(IAM_INSTANCE_PROFILE_RESPONSE)
return template.render(iam_association=iam_association, state="disassociating")
def replace_iam_instance_profile_association(self):
association_id = self._get_param("AssociationId")
iam_instance_profile_name = self._get_param("IamInstanceProfile.Name")
iam_instance_profile_arn = self._get_param("IamInstanceProfile.Arn")
iam_association = self.ec2_backend.replace_iam_instance_profile_association(
association_id, iam_instance_profile_name, iam_instance_profile_arn
)
template = self.response_template(IAM_INSTANCE_PROFILE_RESPONSE)
return template.render(iam_association=iam_association, state="associating")
# https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_AssociateIamInstanceProfile.html
IAM_INSTANCE_PROFILE_RESPONSE = """
<AssociateIamInstanceProfileResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>e10deeaf-7cda-48e7-950b-example</requestId>
<iamInstanceProfileAssociation>
<associationId>{{ iam_association.id }}</associationId>
{% if iam_association.iam_instance_profile %}
<iamInstanceProfile>
<arn>{{ iam_association.iam_instance_profile.arn }}</arn>
<id>{{ iam_association.iam_instance_profile.id }}</id>
</iamInstanceProfile>
{% endif %}
<instanceId>{{ iam_association.instance.id }}</instanceId>
<state>{{ state }}</state>
</iamInstanceProfileAssociation>
</AssociateIamInstanceProfileResponse>
"""
# https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeIamInstanceProfileAssociations.html
# Note: this API description page contains an error! Provided `iamInstanceProfileAssociations` doesn't work, you
# should use `iamInstanceProfileAssociationSet` instead.
DESCRIBE_IAM_INSTANCE_PROFILE_RESPONSE = """
<DescribeIamInstanceProfileAssociationsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>84c2d2a6-12dc-491f-a9ee-example</requestId>
{% if next_token %}<nextToken>{{ next_token }}</nextToken>{% endif %}
<iamInstanceProfileAssociationSet>
{% for iam_association in iam_associations %}
<item>
<associationId>{{ iam_association.id }}</associationId>
<iamInstanceProfile>
<arn>{{ iam_association.iam_instance_profile.arn }}</arn>
<id>{{ iam_association.iam_instance_profile.id }}</id>
</iamInstanceProfile>
<instanceId>{{ iam_association.instance.id }}</instanceId>
<state>{{ iam_association.state }}</state>
</item>
{% endfor %}
</iamInstanceProfileAssociationSet>
</DescribeIamInstanceProfileAssociationsResponse>
"""

View File

@ -1,13 +1,20 @@
from __future__ import unicode_literals
from boto.ec2.instancetype import InstanceType
from moto.packages.boto.ec2.instancetype import InstanceType
from moto.autoscaling import autoscaling_backends
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring, dict_from_querystring
from moto.ec2.exceptions import MissingParameterError
from moto.ec2.utils import (
filters_from_querystring,
dict_from_querystring,
)
from moto.elbv2 import elbv2_backends
from moto.core import ACCOUNT_ID
from copy import deepcopy
import six
class InstanceResponse(BaseResponse):
def describe_instances(self):
@ -44,40 +51,31 @@ class InstanceResponse(BaseResponse):
owner_id = self._get_param("OwnerId")
user_data = self._get_param("UserData")
security_group_names = self._get_multi_param("SecurityGroup")
security_group_ids = self._get_multi_param("SecurityGroupId")
nics = dict_from_querystring("NetworkInterface", self.querystring)
instance_type = self._get_param("InstanceType", if_none="m1.small")
placement = self._get_param("Placement.AvailabilityZone")
subnet_id = self._get_param("SubnetId")
private_ip = self._get_param("PrivateIpAddress")
associate_public_ip = self._get_param("AssociatePublicIpAddress")
key_name = self._get_param("KeyName")
ebs_optimized = self._get_param("EbsOptimized")
instance_initiated_shutdown_behavior = self._get_param(
"InstanceInitiatedShutdownBehavior"
)
tags = self._parse_tag_specification("TagSpecification")
region_name = self.region
kwargs = {
"instance_type": self._get_param("InstanceType", if_none="m1.small"),
"placement": self._get_param("Placement.AvailabilityZone"),
"region_name": self.region,
"subnet_id": self._get_param("SubnetId"),
"owner_id": owner_id,
"key_name": self._get_param("KeyName"),
"security_group_ids": self._get_multi_param("SecurityGroupId"),
"nics": dict_from_querystring("NetworkInterface", self.querystring),
"private_ip": self._get_param("PrivateIpAddress"),
"associate_public_ip": self._get_param("AssociatePublicIpAddress"),
"tags": self._parse_tag_specification("TagSpecification"),
"ebs_optimized": self._get_param("EbsOptimized") or False,
"instance_initiated_shutdown_behavior": self._get_param(
"InstanceInitiatedShutdownBehavior"
),
}
mappings = self._parse_block_device_mapping()
if mappings:
kwargs["block_device_mappings"] = mappings
if self.is_not_dryrun("RunInstance"):
new_reservation = self.ec2_backend.add_instances(
image_id,
min_count,
user_data,
security_group_names,
instance_type=instance_type,
placement=placement,
region_name=region_name,
subnet_id=subnet_id,
owner_id=owner_id,
key_name=key_name,
security_group_ids=security_group_ids,
nics=nics,
private_ip=private_ip,
associate_public_ip=associate_public_ip,
tags=tags,
ebs_optimized=ebs_optimized,
instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior,
image_id, min_count, user_data, security_group_names, **kwargs
)
template = self.response_template(EC2_RUN_INSTANCES)
@ -113,16 +111,34 @@ class InstanceResponse(BaseResponse):
template = self.response_template(EC2_START_INSTANCES)
return template.render(instances=instances)
def _get_list_of_dict_params(self, param_prefix, _dct):
"""
Simplified version of _get_dict_param
Allows you to pass in a custom dict instead of using self.querystring by default
"""
params = []
for key, value in _dct.items():
if key.startswith(param_prefix):
params.append(value)
return params
def describe_instance_status(self):
instance_ids = self._get_multi_param("InstanceId")
include_all_instances = self._get_param("IncludeAllInstances") == "true"
filters = self._get_list_prefix("Filter")
filters = [
{"name": f["name"], "values": self._get_list_of_dict_params("value.", f)}
for f in filters
]
if instance_ids:
instances = self.ec2_backend.get_multi_instances_by_id(instance_ids)
instances = self.ec2_backend.get_multi_instances_by_id(
instance_ids, filters
)
elif include_all_instances:
instances = self.ec2_backend.all_instances()
instances = self.ec2_backend.all_instances(filters)
else:
instances = self.ec2_backend.all_running_instances()
instances = self.ec2_backend.all_running_instances(filters)
template = self.response_template(EC2_INSTANCE_STATUS)
return template.render(instances=instances)
@ -150,6 +166,14 @@ class InstanceResponse(BaseResponse):
return template.render(instance=instance, attribute=attribute, value=value)
def describe_instance_credit_specifications(self):
instance_ids = self._get_multi_param("InstanceId")
instance = self.ec2_backend.describe_instance_credit_specifications(
instance_ids
)
template = self.response_template(EC2_DESCRIBE_INSTANCE_CREDIT_SPECIFICATIONS)
return template.render(instances=instance)
def modify_instance_attribute(self):
handlers = [
self._dot_value_instance_attribute_handler,
@ -246,6 +270,68 @@ class InstanceResponse(BaseResponse):
)
return EC2_MODIFY_INSTANCE_ATTRIBUTE
def _parse_block_device_mapping(self):
device_mappings = self._get_list_prefix("BlockDeviceMapping")
mappings = []
for device_mapping in device_mappings:
self._validate_block_device_mapping(device_mapping)
device_template = deepcopy(BLOCK_DEVICE_MAPPING_TEMPLATE)
device_template["VirtualName"] = device_mapping.get("virtual_name")
device_template["DeviceName"] = device_mapping.get("device_name")
device_template["Ebs"]["SnapshotId"] = device_mapping.get(
"ebs._snapshot_id"
)
device_template["Ebs"]["VolumeSize"] = device_mapping.get(
"ebs._volume_size"
)
device_template["Ebs"]["DeleteOnTermination"] = self._convert_to_bool(
device_mapping.get("ebs._delete_on_termination", False)
)
device_template["Ebs"]["VolumeType"] = device_mapping.get(
"ebs._volume_type"
)
device_template["Ebs"]["Iops"] = device_mapping.get("ebs._iops")
device_template["Ebs"]["Encrypted"] = self._convert_to_bool(
device_mapping.get("ebs._encrypted", False)
)
mappings.append(device_template)
return mappings
@staticmethod
def _validate_block_device_mapping(device_mapping):
if not any(mapping for mapping in device_mapping if mapping.startswith("ebs.")):
raise MissingParameterError("ebs")
if (
"ebs._volume_size" not in device_mapping
and "ebs._snapshot_id" not in device_mapping
):
raise MissingParameterError("size or snapshotId")
@staticmethod
def _convert_to_bool(bool_str):
if isinstance(bool_str, bool):
return bool_str
if isinstance(bool_str, six.text_type):
return str(bool_str).lower() == "true"
return False
BLOCK_DEVICE_MAPPING_TEMPLATE = {
"VirtualName": None,
"DeviceName": None,
"Ebs": {
"SnapshotId": None,
"VolumeSize": None,
"DeleteOnTermination": None,
"VolumeType": None,
"Iops": None,
"Encrypted": None,
},
}
EC2_RUN_INSTANCES = (
"""<RunInstancesResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@ -653,6 +739,18 @@ EC2_DESCRIBE_INSTANCE_ATTRIBUTE = """<DescribeInstanceAttributeResponse xmlns="h
</{{ attribute }}>
</DescribeInstanceAttributeResponse>"""
EC2_DESCRIBE_INSTANCE_CREDIT_SPECIFICATIONS = """<DescribeInstanceCreditSpecificationsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>1b234b5c-d6ef-7gh8-90i1-j2345678901</requestId>
<instanceCreditSpecificationSet>
{% for instance in instances %}
<item>
<instanceId>{{ instance.id }}</instanceId>
<cpuCredits>standard</cpuCredits>
</item>
{% endfor %}
</instanceCreditSpecificationSet>
</DescribeInstanceCreditSpecificationsResponse>"""
EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE = """<DescribeInstanceAttributeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<instanceId>{{ instance.id }}</instanceId>
@ -720,13 +818,25 @@ EC2_DESCRIBE_INSTANCE_TYPES = """<?xml version="1.0" encoding="UTF-8"?>
<instanceTypeSet>
{% for instance_type in instance_types %}
<item>
<name>{{ instance_type.name }}</name>
<vcpu>{{ instance_type.cores }}</vcpu>
<memory>{{ instance_type.memory }}</memory>
<storageSize>{{ instance_type.disk }}</storageSize>
<storageCount>{{ instance_type.storageCount }}</storageCount>
<maxIpAddresses>{{ instance_type.maxIpAddresses }}</maxIpAddresses>
<ebsOptimizedAvailable>{{ instance_type.ebsOptimizedAvailable }}</ebsOptimizedAvailable>
<instanceType>{{ instance_type.name }}</instanceType>
<vCpuInfo>
<defaultVCpus>{{ instance_type.cores }}</defaultVCpus>
<defaultCores>{{ instance_type.cores }}</defaultCores>
<defaultThreadsPerCore>1</defaultThreadsPerCore>
</vCpuInfo>
<memoryInfo>
<sizeInMiB>{{ instance_type.memory }}</sizeInMiB>
</memoryInfo>
<instanceStorageInfo>
<totalSizeInGB>{{ instance_type.disk }}</totalSizeInGB>
</instanceStorageInfo>
<processorInfo>
<supportedArchitectures>
<item>
x86_64
</item>
</supportedArchitectures>
</processorInfo>
</item>
{% endfor %}
</instanceTypeSet>

View File

@ -14,7 +14,10 @@ class InternetGateways(BaseResponse):
def create_internet_gateway(self):
if self.is_not_dryrun("CreateInternetGateway"):
igw = self.ec2_backend.create_internet_gateway()
tags = self._get_multi_param("TagSpecification")
if tags:
tags = tags[0].get("Tag")
igw = self.ec2_backend.create_internet_gateway(tags=tags)
template = self.response_template(CREATE_INTERNET_GATEWAY_RESPONSE)
return template.render(internet_gateway=igw)

View File

@ -6,7 +6,10 @@ from moto.ec2.utils import filters_from_querystring
class NetworkACLs(BaseResponse):
def create_network_acl(self):
vpc_id = self._get_param("VpcId")
network_acl = self.ec2_backend.create_network_acl(vpc_id)
tags = self._get_multi_param("TagSpecification")
if tags:
tags = tags[0].get("Tag")
network_acl = self.ec2_backend.create_network_acl(vpc_id, tags=tags)
template = self.response_template(CREATE_NETWORK_ACL_RESPONSE)
return template.render(network_acl=network_acl)
@ -83,7 +86,7 @@ class NetworkACLs(BaseResponse):
def describe_network_acls(self):
network_acl_ids = self._get_multi_param("NetworkAclId")
filters = filters_from_querystring(self.querystring)
network_acls = self.ec2_backend.get_all_network_acls(network_acl_ids, filters)
network_acls = self.ec2_backend.describe_network_acls(network_acl_ids, filters)
template = self.response_template(DESCRIBE_NETWORK_ACL_RESPONSE)
return template.render(network_acls=network_acls)
@ -161,7 +164,7 @@ DESCRIBE_NETWORK_ACL_RESPONSE = """
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<key>{{ tag.key}}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}

View File

@ -16,6 +16,7 @@ class RouteTables(BaseResponse):
def create_route(self):
route_table_id = self._get_param("RouteTableId")
destination_cidr_block = self._get_param("DestinationCidrBlock")
destination_ipv6_cidr_block = self._get_param("DestinationIpv6CidrBlock")
gateway_id = self._get_param("GatewayId")
instance_id = self._get_param("InstanceId")
nat_gateway_id = self._get_param("NatGatewayId")
@ -25,6 +26,7 @@ class RouteTables(BaseResponse):
self.ec2_backend.create_route(
route_table_id,
destination_cidr_block,
destination_ipv6_cidr_block,
gateway_id=gateway_id,
instance_id=instance_id,
nat_gateway_id=nat_gateway_id,
@ -37,7 +39,10 @@ class RouteTables(BaseResponse):
def create_route_table(self):
vpc_id = self._get_param("VpcId")
route_table = self.ec2_backend.create_route_table(vpc_id)
tags = self._get_multi_param("TagSpecification")
if tags:
tags = tags[0].get("Tag")
route_table = self.ec2_backend.create_route_table(vpc_id, tags)
template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE)
return template.render(route_table=route_table)

View File

@ -20,7 +20,11 @@ def parse_sg_attributes_from_dict(sg_attributes):
ip_ranges = []
ip_ranges_tree = sg_attributes.get("IpRanges") or {}
for ip_range_idx in sorted(ip_ranges_tree.keys()):
ip_ranges.append(ip_ranges_tree[ip_range_idx]["CidrIp"][0])
ip_range = {"CidrIp": ip_ranges_tree[ip_range_idx]["CidrIp"][0]}
if ip_ranges_tree[ip_range_idx].get("Description"):
ip_range["Description"] = ip_ranges_tree[ip_range_idx].get("Description")[0]
ip_ranges.append(ip_range)
source_groups = []
source_group_ids = []
@ -61,6 +65,7 @@ class SecurityGroups(BaseResponse):
source_groups,
source_group_ids,
) = parse_sg_attributes_from_dict(querytree)
yield (
group_name_or_id,
ip_protocol,
@ -211,7 +216,10 @@ DESCRIBE_SECURITY_GROUPS_RESPONSE = (
<ipRanges>
{% for ip_range in rule.ip_ranges %}
<item>
<cidrIp>{{ ip_range }}</cidrIp>
<cidrIp>{{ ip_range['CidrIp'] }}</cidrIp>
{% if ip_range['Description'] %}
<description>{{ ip_range['Description'] }}</description>
{% endif %}
</item>
{% endfor %}
</ipRanges>
@ -242,7 +250,10 @@ DESCRIBE_SECURITY_GROUPS_RESPONSE = (
<ipRanges>
{% for ip_range in rule.ip_ranges %}
<item>
<cidrIp>{{ ip_range }}</cidrIp>
<cidrIp>{{ ip_range['CidrIp'] }}</cidrIp>
{% if ip_range['Description'] %}
<description>{{ ip_range['Description'] }}</description>
{% endif %}
</item>
{% endfor %}
</ipRanges>

View File

@ -9,12 +9,23 @@ class Subnets(BaseResponse):
def create_subnet(self):
vpc_id = self._get_param("VpcId")
cidr_block = self._get_param("CidrBlock")
availability_zone = self._get_param(
"AvailabilityZone",
if_none=random.choice(self.ec2_backend.describe_availability_zones()).name,
)
availability_zone = self._get_param("AvailabilityZone")
availability_zone_id = self._get_param("AvailabilityZoneId")
tags = self._get_multi_param("TagSpecification")
if tags:
tags = tags[0].get("Tag")
if not availability_zone and not availability_zone_id:
availability_zone = random.choice(
self.ec2_backend.describe_availability_zones()
).name
subnet = self.ec2_backend.create_subnet(
vpc_id, cidr_block, availability_zone, context=self
vpc_id,
cidr_block,
availability_zone,
availability_zone_id,
context=self,
tags=tags,
)
template = self.response_template(CREATE_SUBNET_RESPONSE)
return template.render(subnet=subnet)
@ -62,6 +73,16 @@ CREATE_SUBNET_RESPONSE = """
<assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation>
<ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
<tagSet>
{% for tag in subnet.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</subnet>
</CreateSubnetResponse>"""
@ -78,7 +99,7 @@ DESCRIBE_SUBNETS_RESPONSE = """
{% for subnet in subnets %}
<item>
<subnetId>{{ subnet.id }}</subnetId>
<state>available</state>
<state>{{ subnet.state }}</state>
<vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>{{ subnet.available_ip_addresses }}</availableIpAddressCount>

View File

@ -2,7 +2,8 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.models import validate_resource_ids
from moto.ec2.utils import tags_from_query_string, filters_from_querystring
from moto.ec2.utils import filters_from_querystring
from moto.core.utils import tags_from_query_string
class TagResponse(BaseResponse):

View File

@ -86,6 +86,7 @@ DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = (
<ownerId>777788889999</ownerId>
<vpcId>{{ vpc_pcx.vpc.id }}</vpcId>
<cidrBlock>{{ vpc_pcx.vpc.cidr_block }}</cidrBlock>
<region>{{ vpc_pcx.vpc.ec2_backend.region_name }}</region>
</requesterVpcInfo>
<accepterVpcInfo>
<ownerId>"""
@ -98,6 +99,7 @@ DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = (
<allowEgressFromLocalVpcToRemoteClassicLink>true</allowEgressFromLocalVpcToRemoteClassicLink>
<allowDnsResolutionFromRemoteVpc>false</allowDnsResolutionFromRemoteVpc>
</peeringOptions>
<region>{{ vpc_pcx.peer_vpc.ec2_backend.region_name }}</region>
</accepterVpcInfo>
<status>
<code>{{ vpc_pcx._status.code }}</code>
@ -128,6 +130,7 @@ ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = (
<ownerId>777788889999</ownerId>
<vpcId>{{ vpc_pcx.vpc.id }}</vpcId>
<cidrBlock>{{ vpc_pcx.vpc.cidr_block }}</cidrBlock>
<region>{{ vpc_pcx.vpc.ec2_backend.region_name }}</region>
</requesterVpcInfo>
<accepterVpcInfo>
<ownerId>"""
@ -140,6 +143,7 @@ ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = (
<allowEgressFromLocalVpcToRemoteClassicLink>false</allowEgressFromLocalVpcToRemoteClassicLink>
<allowDnsResolutionFromRemoteVpc>false</allowDnsResolutionFromRemoteVpc>
</peeringOptions>
<region>{{ vpc_pcx.peer_vpc.ec2_backend.region_name }}</region>
</accepterVpcInfo>
<status>
<code>{{ vpc_pcx._status.code }}</code>

View File

@ -14,14 +14,19 @@ class VPCs(BaseResponse):
def create_vpc(self):
cidr_block = self._get_param("CidrBlock")
tags = self._get_multi_param("TagSpecification")
instance_tenancy = self._get_param("InstanceTenancy", if_none="default")
amazon_provided_ipv6_cidr_blocks = self._get_param(
"AmazonProvidedIpv6CidrBlock"
)
if tags:
tags = tags[0].get("Tag")
vpc = self.ec2_backend.create_vpc(
cidr_block,
instance_tenancy,
amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_blocks,
tags=tags,
)
doc_date = self._get_doc_date()
template = self.response_template(CREATE_VPC_RESPONSE)
@ -163,6 +168,39 @@ class VPCs(BaseResponse):
cidr_block_state="disassociating",
)
def create_vpc_endpoint(self):
vpc_id = self._get_param("VpcId")
service_name = self._get_param("ServiceName")
route_table_ids = self._get_multi_param("RouteTableId")
subnet_ids = self._get_multi_param("SubnetId")
type = self._get_param("VpcEndpointType")
policy_document = self._get_param("PolicyDocument")
client_token = self._get_param("ClientToken")
tag_specifications = self._get_param("TagSpecifications")
private_dns_enabled = self._get_param("PrivateDNSEnabled")
security_group = self._get_param("SecurityGroup")
vpc_end_point = self.ec2_backend.create_vpc_endpoint(
vpc_id=vpc_id,
service_name=service_name,
type=type,
policy_document=policy_document,
route_table_ids=route_table_ids,
subnet_ids=subnet_ids,
client_token=client_token,
security_group=security_group,
tag_specifications=tag_specifications,
private_dns_enabled=private_dns_enabled,
)
template = self.response_template(CREATE_VPC_END_POINT)
return template.render(vpc_end_point=vpc_end_point)
def describe_vpc_endpoint_services(self):
vpc_end_point_services = self.ec2_backend.get_vpc_end_point_services()
template = self.response_template(DESCRIBE_VPC_ENDPOINT_RESPONSE)
return template.render(vpc_end_points=vpc_end_point_services)
CREATE_VPC_RESPONSE = """
<CreateVpcResponse xmlns="http://ec2.amazonaws.com/doc/{{doc_date}}/">
@ -384,3 +422,72 @@ IPV6_DISASSOCIATE_VPC_CIDR_BLOCK_RESPONSE = """
</ipv6CidrBlockState>
</ipv6CidrBlockAssociation>
</DisassociateVpcCidrBlockResponse>"""
CREATE_VPC_END_POINT = """ <CreateVpcEndpointResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<vpcEndpoint>
<policyDocument>{{ vpc_end_point.policy_document }}</policyDocument>
<state> available </state>
<vpcEndpointPolicySupported> false </vpcEndpointPolicySupported>
<serviceName>{{ vpc_end_point.service_name }}</serviceName>
<vpcId>{{ vpc_end_point.vpc_id }}</vpcId>
<vpcEndpointId>{{ vpc_end_point.id }}</vpcEndpointId>
<routeTableIdSet>
{% for routeid in vpc_end_point.route_table_ids %}
<item>{{ routeid }}</item>
{% endfor %}
</routeTableIdSet>
<networkInterfaceIdSet>
{% for network_interface_id in vpc_end_point.network_interface_ids %}
<item>{{ network_interface_id }}</item>
{% endfor %}
</networkInterfaceIdSet>
<subnetIdSet>
{% for subnetId in vpc_end_point.subnet_ids %}
<item>{{ subnetId }}</item>
{% endfor %}
</subnetIdSet>
<dnsEntrySet>
{% if vpc_end_point.dns_entries %}
{% for entry in vpc_end_point.dns_entries %}
<item>
<hostedZoneId>{{ entry["hosted_zone_id"] }}</hostedZoneId>
<dnsName>{{ entry["dns_name"] }}</dnsName>
</item>
{% endfor %}
{% endif %}
</dnsEntrySet>
<creationTimestamp>{{ vpc_end_point.created_at }}</creationTimestamp>
</vpcEndpoint>
</CreateVpcEndpointResponse>"""
DESCRIBE_VPC_ENDPOINT_RESPONSE = """<DescribeVpcEndpointServicesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>19a9ff46-7df6-49b8-9726-3df27527089d</requestId>
<serviceNameSet>
{% for serviceName in vpc_end_points.services %}
<item>{{ serviceName }}</item>
{% endfor %}
</serviceNameSet>
<serviceDetailSet>
<item>
{% for service in vpc_end_points.servicesDetails %}
<owner>amazon</owner>
<serviceType>
<item>
<serviceType>{{ service.type }}</serviceType>
</item>
</serviceType>
<baseEndpointDnsNameSet>
<item>{{ ".".join((service.service_name.split(".")[::-1])) }}</item>
</baseEndpointDnsNameSet>
<acceptanceRequired>false</acceptanceRequired>
<availabilityZoneSet>
{% for zone in vpc_end_points.availability_zones %}
<item>{{ zone.name }}</item>
{% endfor %}
</availabilityZoneSet>
<serviceName>{{ service.service_name }}</serviceName>
<vpcEndpointPolicySupported>true</vpcEndpointPolicySupported>
{% endfor %}
</item>
</serviceDetailSet>
</DescribeVpcEndpointServicesResponse>"""

View File

@ -7,7 +7,7 @@ class VPNConnections(BaseResponse):
def create_vpn_connection(self):
type = self._get_param("Type")
cgw_id = self._get_param("CustomerGatewayId")
vgw_id = self._get_param("VPNGatewayId")
vgw_id = self._get_param("VpnGatewayId")
static_routes = self._get_param("StaticRoutesOnly")
vpn_connection = self.ec2_backend.create_vpn_connection(
type, cgw_id, vgw_id, static_routes_only=static_routes

Some files were not shown because too many files have changed in this diff Show More