Merge remote-tracking branch 'upstream/master'

This commit is contained in:
mickeypash 2017-10-23 19:16:08 +01:00
commit 15a1bc43a9
220 changed files with 34437 additions and 1657 deletions

1
.gitignore vendored
View File

@ -12,3 +12,4 @@ build/
*.swp *.swp
.DS_Store .DS_Store
python_env python_env
.ropeproject/

View File

@ -1,23 +1,36 @@
language: python language: python
sudo: false sudo: false
services:
- docker
python: python:
- 2.7 - 2.7
- 3.6 - 3.6
env: env:
- TEST_SERVER_MODE=false - TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true - TEST_SERVER_MODE=true
before_install:
- export BOTO_CONFIG=/dev/null
install: install:
- travis_retry pip install boto==2.45.0 # We build moto first so the docker container doesn't try to compile it as well, also note we don't use
- travis_retry pip install boto3 # -d for docker run so the logs show up in travis
- travis_retry pip install . # Python images come from here: https://hub.docker.com/_/python/
- travis_retry pip install -r requirements-dev.txt
- travis_retry pip install coveralls
- | - |
python setup.py sdist
if [ "$TEST_SERVER_MODE" = "true" ]; then if [ "$TEST_SERVER_MODE" = "true" ]; then
AWS_SECRET_ACCESS_KEY=server_secret AWS_ACCESS_KEY_ID=server_key moto_server -p 5000& docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh &
export AWS_SECRET_ACCESS_KEY=foobar_secret export AWS_SECRET_ACCESS_KEY=foobar_secret
export AWS_ACCESS_KEY_ID=foobar_key export AWS_ACCESS_KEY_ID=foobar_key
fi fi
travis_retry pip install boto==2.45.0
travis_retry pip install boto3
travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt
if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py
fi
script: script:
- make test - make test
after_success: after_success:

View File

@ -3,6 +3,141 @@ Moto Changelog
Latest Latest
------ ------
1.1.22
-----
* Lambda policies
* Dynamodb filter expressions
* EC2 Spot fleet improvements
1.1.21
-----
* ELBv2 bugfixes
* Removing GPL'd dependency
1.1.20
-----
* Improved `make scaffold`
* Implemented IAM attached group policies
* Implemented skeleton of Cloudwatch Logs
* Redshift: fixed multi-params
* Redshift: implement taggable resources
* Lambda + SNS: Major enhancements
1.1.19
-----
* Fixing regression from 1.1.15
1.1.15
-----
* Polly implementation
* Added EC2 instance info
* SNS publish by phone number
1.1.14
-----
* ACM implementation
* Added `make scaffold`
* X-Ray implementation
1.1.13
-----
* Created alpine-based Dockerfile (dockerhub: motoserver/moto)
* SNS.SetSMSAttributes & SNS.GetSMSAttributes + Filtering
* S3 ACL implementation
* pushing to Dockerhub on `make publish`
1.1.12
-----
* implemented all AWS managed policies in source
* fixing Dynamodb CapacityUnits format
* S3 ACL implementation
1.1.11
-----
* S3 authentication
* SSM get_parameter
* ELBv2 target group tagging
* EC2 Security group filters
1.1.10
-----
* EC2 vpc address filtering
* EC2 elastic ip dissociation
* ELBv2 target group tagging
* fixed complexity of accepting new filter implementations
1.1.9
-----
* EC2 root device mapping
1.1.8
-----
* Lambda get_function for function created with zipfile
* scripts/implementation_coverage.py
1.1.7
-----
* Lambda invoke_async
* EC2 keypair filtering
1.1.6
-----
* Dynamo ADD and DELETE operations in update expressions
* Lambda tag support
1.1.5
-----
* Dynamo allow ADD update_item of a string set
* Handle max-keys in list-objects
* bugfixes in pagination
1.1.3
-----
* EC2 vpc_id in responses
1.1.2
-----
* IAM account aliases
* SNS subscription attributes
* bugfixes in Dynamo, CFN, and EC2
1.1.1
-----
* EC2 group-id filter
* EC2 list support for filters
1.1.0
-----
* Add ELBv2
* IAM user policies
* RDS snapshots
* IAM policy versions
1.0.1
-----
* Add Cloudformation exports
* Add ECR
* IAM policy versions
1.0.0 1.0.0
----- -----

View File

@ -1,11 +1,22 @@
FROM python:2 FROM alpine:3.6
RUN apk add --no-cache --update \
gcc \
musl-dev \
python3-dev \
libffi-dev \
openssl-dev \
python3
ADD . /moto/ ADD . /moto/
ENV PYTHONUNBUFFERED 1 ENV PYTHONUNBUFFERED 1
WORKDIR /moto/ WORKDIR /moto/
RUN pip install ".[server]" RUN python3 -m ensurepip && \
rm -r /usr/lib/python*/ensurepip && \
pip3 --no-cache-dir install --upgrade pip setuptools && \
pip3 --no-cache-dir install ".[server]"
CMD ["moto_server"] ENTRYPOINT ["/usr/bin/moto_server", "-H", "0.0.0.0"]
EXPOSE 5000 EXPOSE 5000

View File

@ -1,3 +1,5 @@
include README.md LICENSE AUTHORS.md include README.md LICENSE AUTHORS.md
include requirements.txt requirements-dev.txt tox.ini include requirements.txt requirements-dev.txt tox.ini
include moto/ec2/resources/instance_types.json
recursive-include moto/templates *
recursive-include tests * recursive-include tests *

View File

@ -15,5 +15,22 @@ test: lint
test_server: test_server:
@TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/ @TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/
publish: aws_managed_policies:
scripts/update_managed_policies.py
upload_pypi_artifact:
python setup.py sdist bdist_wheel upload python setup.py sdist bdist_wheel upload
push_dockerhub_image:
docker build -t motoserver/moto .
docker push motoserver/moto
tag_github_release:
git tag `python setup.py --version`
git push origin `python setup.py --version`
publish: upload_pypi_artifact push_dockerhub_image tag_github_release
scaffold:
@pip install -r requirements-dev.txt > /dev/null
exec python scripts/scaffold.py

View File

@ -58,6 +58,8 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| Service Name | Decorator | Development Status | | Service Name | Decorator | Development Status |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| ACM | @mock_acm | all endpoints done |
|------------------------------------------------------------------------------|
| API Gateway | @mock_apigateway | core endpoints done | | API Gateway | @mock_apigateway | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| Autoscaling | @mock_autoscaling| core endpoints done | | Autoscaling | @mock_autoscaling| core endpoints done |
@ -78,22 +80,31 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
| - Security Groups | | core endpoints done | | - Security Groups | | core endpoints done |
| - Tags | | all endpoints done | | - Tags | | all endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| ECR | @mock_ecr | basic endpoints done |
|------------------------------------------------------------------------------|
| ECS | @mock_ecs | basic endpoints done | | ECS | @mock_ecs | basic endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| ELB | @mock_elb | core endpoints done | | ELB | @mock_elb | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| ELBv2 | @mock_elbv2 | core endpoints done |
|------------------------------------------------------------------------------|
| EMR | @mock_emr | core endpoints done | | EMR | @mock_emr | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| Glacier | @mock_glacier | core endpoints done | | Glacier | @mock_glacier | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| IAM | @mock_iam | core endpoints done | | IAM | @mock_iam | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| Lambda | @mock_lambda | basic endpoints done | | Lambda | @mock_lambda | basic endpoints done, requires |
| | | docker |
|------------------------------------------------------------------------------|
| Logs | @mock_logs | basic endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| Kinesis | @mock_kinesis | core endpoints done | | Kinesis | @mock_kinesis | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| KMS | @mock_kms | basic endpoints done | | KMS | @mock_kms | basic endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| Polly | @mock_polly | all endpoints done |
|------------------------------------------------------------------------------|
| RDS | @mock_rds | core endpoints done | | RDS | @mock_rds | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| RDS2 | @mock_rds2 | core endpoints done | | RDS2 | @mock_rds2 | core endpoints done |
@ -106,7 +117,7 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| SES | @mock_ses | core endpoints done | | SES | @mock_ses | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| SNS | @mock_sns | core endpoints done | | SNS | @mock_sns | all endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| SQS | @mock_sqs | core endpoints done | | SQS | @mock_sqs | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
@ -114,7 +125,9 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| STS | @mock_sts | core endpoints done | | STS | @mock_sts | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| SWF | @mock_sfw | basic endpoints done | | SWF | @mock_swf | basic endpoints done |
|------------------------------------------------------------------------------|
| X-Ray | @mock_xray | core endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
``` ```
@ -123,28 +136,51 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
Imagine you have a function that you use to launch new ec2 instances: Imagine you have a function that you use to launch new ec2 instances:
```python ```python
import boto import boto3
def add_servers(ami_id, count): def add_servers(ami_id, count):
conn = boto.connect_ec2('the_key', 'the_secret') client = boto3.client('ec2', region_name='us-west-1')
for index in range(count): client.run_instances(ImageId=ami_id, MinCount=count, MaxCount=count)
conn.run_instances(ami_id)
``` ```
To test it: To test it:
```python ```python
from . import add_servers from . import add_servers
from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_add_servers(): def test_add_servers():
add_servers('ami-1234abcd', 2) add_servers('ami-1234abcd', 2)
conn = boto.connect_ec2('the_key', 'the_secret') client = boto3.client('ec2', region_name='us-west-1')
reservations = conn.get_all_instances() instances = client.describe_instances()['Reservations'][0]['Instances']
assert len(reservations) == 2 assert len(instances) == 2
instance1 = reservations[0].instances[0] instance1 = instances[0]
assert instance1.image_id == 'ami-1234abcd' assert instance1['ImageId'] == 'ami-1234abcd'
```
#### Using moto 1.0.X with boto2
moto 1.0.X mock docorators are defined for boto3 and do not work with boto2. Use the @mock_AWSSVC_deprecated to work with boto2.
Using moto with boto2
```python
from moto import mock_ec2_deprecated
import boto
@mock_ec2_deprecated
def test_something_with_ec2():
ec2_conn = boto.ec2.connect_to_region('us-east-1')
ec2_conn.get_only_instances(instance_ids='i-123456')
```
When using both boto2 and boto3, one can do this to avoid confusion:
```python
from moto import mock_ec2_deprecated as mock_ec2_b2
from moto import mock_ec2
``` ```
## Usage ## Usage
@ -156,13 +192,14 @@ All of the services can be used as a decorator, context manager, or in a raw for
```python ```python
@mock_s3 @mock_s3
def test_my_model_save(): def test_my_model_save():
conn = boto.connect_s3() # Create Bucket so that test can run
conn.create_bucket('mybucket') conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome') model_instance = MyModel('steve', 'is awesome')
model_instance.save() model_instance.save()
body = conn.Object('mybucket', 'steve').get()['Body'].read().decode()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome' assert body == 'is awesome'
``` ```
### Context Manager ### Context Manager
@ -170,13 +207,13 @@ def test_my_model_save():
```python ```python
def test_my_model_save(): def test_my_model_save():
with mock_s3(): with mock_s3():
conn = boto.connect_s3() conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket('mybucket') conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome') model_instance = MyModel('steve', 'is awesome')
model_instance.save() model_instance.save()
body = conn.Object('mybucket', 'steve').get()['Body'].read().decode()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome' assert body == 'is awesome'
``` ```
@ -187,13 +224,13 @@ def test_my_model_save():
mock = mock_s3() mock = mock_s3()
mock.start() mock.start()
conn = boto.connect_s3() conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket('mybucket') conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome') model_instance = MyModel('steve', 'is awesome')
model_instance.save() model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome' assert conn.Object('mybucket', 'steve').get()['Body'].read().decode() == 'is awesome'
mock.stop() mock.stop()
``` ```

View File

@ -74,7 +74,7 @@ Currently implemented Services:
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| STS | @mock_sts | core endpoints done | | STS | @mock_sts | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| SWF | @mock_sfw | basic endpoints done | | SWF | @mock_swf | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+

View File

@ -43,6 +43,7 @@ Currently implemented Services:
| ECS | @mock_ecs | basic endpoints done | | ECS | @mock_ecs | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| ELB | @mock_elb | core endpoints done | | ELB | @mock_elb | core endpoints done |
| | @mock_elbv2 | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| EMR | @mock_emr | core endpoints done | | EMR | @mock_emr | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
@ -74,7 +75,7 @@ Currently implemented Services:
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| STS | @mock_sts | core endpoints done | | STS | @mock_sts | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| SWF | @mock_sfw | basic endpoints done | | SWF | @mock_swf | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+

View File

@ -3,8 +3,9 @@ import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL) # logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto' __title__ = 'moto'
__version__ = '1.0.0' __version__ = '1.0.1'
from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # flake8: noqa from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # flake8: noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # flake8: noqa from .awslambda import mock_lambda, mock_lambda_deprecated # flake8: noqa
@ -14,15 +15,18 @@ from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # fla
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # flake8: noqa from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # flake8: noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # flake8: noqa from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # flake8: noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # flake8: noqa from .ec2 import mock_ec2, mock_ec2_deprecated # flake8: noqa
from .ecr import mock_ecr, mock_ecr_deprecated # flake8: noqa
from .ecs import mock_ecs, mock_ecs_deprecated # flake8: noqa from .ecs import mock_ecs, mock_ecs_deprecated # flake8: noqa
from .elb import mock_elb, mock_elb_deprecated # flake8: noqa from .elb import mock_elb, mock_elb_deprecated # flake8: noqa
from .elbv2 import mock_elbv2 # flake8: noqa
from .emr import mock_emr, mock_emr_deprecated # flake8: noqa from .emr import mock_emr, mock_emr_deprecated # flake8: noqa
from .events import mock_events # flake8: noqa from .events import mock_events # flake8: noqa
from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .iam import mock_iam, mock_iam_deprecated # flake8: noqa from .iam import mock_iam, mock_iam_deprecated # flake8: noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa
from .kms import mock_kms, mock_kms_deprecated # flake8: noqa from .kms import mock_kms, mock_kms_deprecated # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .polly import mock_polly # flake8: noqa
from .rds import mock_rds, mock_rds_deprecated # flake8: noqa from .rds import mock_rds, mock_rds_deprecated # flake8: noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa
from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa
@ -34,6 +38,9 @@ from .sts import mock_sts, mock_sts_deprecated # flake8: noqa
from .ssm import mock_ssm # flake8: noqa from .ssm import mock_ssm # flake8: noqa
from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa
from .swf import mock_swf, mock_swf_deprecated # flake8: noqa from .swf import mock_swf, mock_swf_deprecated # flake8: noqa
from .xray import mock_xray, mock_xray_client, XRaySegment # flake8: noqa
from .logs import mock_logs, mock_logs_deprecated # flake8: noqa
from .batch import mock_batch # flake8: noqa
try: try:

6
moto/acm/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import acm_backends
from ..core.models import base_decorator
acm_backend = acm_backends['us-east-1']
mock_acm = base_decorator(acm_backends)

395
moto/acm/models.py Normal file
View File

@ -0,0 +1,395 @@
from __future__ import unicode_literals
import re
import json
import datetime
from moto.core import BaseBackend, BaseModel
from moto.ec2 import ec2_backends
from .utils import make_arn_for_certificate
import cryptography.x509
import cryptography.hazmat.primitives.asymmetric.rsa
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.backends import default_backend
DEFAULT_ACCOUNT_ID = 123456789012
GOOGLE_ROOT_CA = b"""-----BEGIN CERTIFICATE-----
MIIEKDCCAxCgAwIBAgIQAQAhJYiw+lmnd+8Fe2Yn3zANBgkqhkiG9w0BAQsFADBC
MQswCQYDVQQGEwJVUzEWMBQGA1UEChMNR2VvVHJ1c3QgSW5jLjEbMBkGA1UEAxMS
R2VvVHJ1c3QgR2xvYmFsIENBMB4XDTE3MDUyMjExMzIzN1oXDTE4MTIzMTIzNTk1
OVowSTELMAkGA1UEBhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMT
HEdvb2dsZSBJbnRlcm5ldCBBdXRob3JpdHkgRzIwggEiMA0GCSqGSIb3DQEBAQUA
A4IBDwAwggEKAoIBAQCcKgR3XNhQkToGo4Lg2FBIvIk/8RlwGohGfuCPxfGJziHu
Wv5hDbcyRImgdAtTT1WkzoJile7rWV/G4QWAEsRelD+8W0g49FP3JOb7kekVxM/0
Uw30SvyfVN59vqBrb4fA0FAfKDADQNoIc1Fsf/86PKc3Bo69SxEE630k3ub5/DFx
+5TVYPMuSq9C0svqxGoassxT3RVLix/IGWEfzZ2oPmMrhDVpZYTIGcVGIvhTlb7j
gEoQxirsupcgEcc5mRAEoPBhepUljE5SdeK27QjKFPzOImqzTs9GA5eXA37Asd57
r0Uzz7o+cbfe9CUlwg01iZ2d+w4ReYkeN8WvjnJpAgMBAAGjggERMIIBDTAfBgNV
HSMEGDAWgBTAephojYn7qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1
dvWBtrtiGrpagS8wDgYDVR0PAQH/BAQDAgEGMC4GCCsGAQUFBwEBBCIwIDAeBggr
BgEFBQcwAYYSaHR0cDovL2cuc3ltY2QuY29tMBIGA1UdEwEB/wQIMAYBAf8CAQAw
NQYDVR0fBC4wLDAqoCigJoYkaHR0cDovL2cuc3ltY2IuY29tL2NybHMvZ3RnbG9i
YWwuY3JsMCEGA1UdIAQaMBgwDAYKKwYBBAHWeQIFATAIBgZngQwBAgIwHQYDVR0l
BBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMA0GCSqGSIb3DQEBCwUAA4IBAQDKSeWs
12Rkd1u+cfrP9B4jx5ppY1Rf60zWGSgjZGaOHMeHgGRfBIsmr5jfCnC8vBk97nsz
qX+99AXUcLsFJnnqmseYuQcZZTTMPOk/xQH6bwx+23pwXEz+LQDwyr4tjrSogPsB
E4jLnD/lu3fKOmc2887VJwJyQ6C9bgLxRwVxPgFZ6RGeGvOED4Cmong1L7bHon8X
fOGLVq7uZ4hRJzBgpWJSwzfVO+qFKgE4h6LPcK2kesnE58rF2rwjMvL+GMJ74N87
L9TQEOaWTPtEtyFkDbkAlDASJodYmDkFOA/MgkgMCkdm7r+0X8T/cKjhf4t5K7hl
MqO5tzHpCvX2HzLc
-----END CERTIFICATE-----"""
# Added google root CA as AWS returns chain you gave it + root CA (provided or not)
# so for now a cheap response is just give any old root CA
def datetime_to_epoch(date):
# As only Py3 has datetime.timestamp()
return int((date - datetime.datetime(1970, 1, 1)).total_seconds())
class AWSError(Exception):
TYPE = None
STATUS = 400
def __init__(self, message):
self.message = message
def response(self):
resp = {'__type': self.TYPE, 'message': self.message}
return json.dumps(resp), dict(status=self.STATUS)
class AWSValidationException(AWSError):
TYPE = 'ValidationException'
class AWSResourceNotFoundException(AWSError):
TYPE = 'ResourceNotFoundException'
class CertBundle(BaseModel):
def __init__(self, certificate, private_key, chain=None, region='us-east-1', arn=None, cert_type='IMPORTED', cert_status='ISSUED'):
self.created_at = datetime.datetime.now()
self.cert = certificate
self._cert = None
self.common_name = None
self.key = private_key
self._key = None
self.chain = chain
self.tags = {}
self._chain = None
self.type = cert_type # Should really be an enum
self.status = cert_status # Should really be an enum
# AWS always returns your chain + root CA
if self.chain is None:
self.chain = GOOGLE_ROOT_CA
else:
self.chain += b'\n' + GOOGLE_ROOT_CA
# Takes care of PEM checking
self.validate_pk()
self.validate_certificate()
if chain is not None:
self.validate_chain()
# TODO check cert is valid, or if self-signed then a chain is provided, otherwise
# raise AWSValidationException('Provided certificate is not a valid self signed. Please provide either a valid self-signed certificate or certificate chain.')
# Used for when one wants to overwrite an arn
if arn is None:
self.arn = make_arn_for_certificate(DEFAULT_ACCOUNT_ID, region)
else:
self.arn = arn
@classmethod
def generate_cert(cls, domain_name, sans=None):
if sans is None:
sans = set()
else:
sans = set(sans)
sans.add(domain_name)
sans = [cryptography.x509.DNSName(item) for item in sans]
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())
subject = cryptography.x509.Name([
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, u"CA"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.LOCALITY_NAME, u"San Francisco"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"My Company"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, domain_name),
])
issuer = cryptography.x509.Name([ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"Amazon"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, u"Server CA 1B"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, u"Amazon"),
])
cert = cryptography.x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
key.public_key()
).serial_number(
cryptography.x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=365)
).add_extension(
cryptography.x509.SubjectAlternativeName(sans),
critical=False,
).sign(key, hashes.SHA512(), default_backend())
cert_armored = cert.public_bytes(serialization.Encoding.PEM)
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION')
def validate_pk(self):
try:
self._key = serialization.load_pem_private_key(self.key, password=None, backend=default_backend())
if self._key.key_size > 2048:
AWSValidationException('The private key length is not supported. Only 1024-bit and 2048-bit are allowed.')
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The private key is not PEM-encoded or is not valid.')
def validate_certificate(self):
try:
self._cert = cryptography.x509.load_pem_x509_certificate(self.cert, default_backend())
now = datetime.datetime.utcnow()
if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate has expired, is not valid.')
if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate is not in effect yet, is not valid.')
# Extracting some common fields for ease of use
# Have to search through cert.subject for OIDs
self.common_name = self._cert.subject.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.')
def validate_chain(self):
try:
self._chain = []
for cert_armored in self.chain.split(b'-\n-'):
# Would leave encoded but Py2 does not have raw binary strings
cert_armored = cert_armored.decode()
# Fix missing -'s on split
cert_armored = re.sub(r'^----B', '-----B', cert_armored)
cert_armored = re.sub(r'E----$', 'E-----', cert_armored)
cert = cryptography.x509.load_pem_x509_certificate(cert_armored.encode(), default_backend())
self._chain.append(cert)
now = datetime.datetime.now()
if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate chain has expired, is not valid.')
if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate chain is not in effect yet, is not valid.')
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.')
def check(self):
# Basically, if the certificate is pending, and then checked again after 1 min
# It will appear as if its been validated
if self.type == 'AMAZON_ISSUED' and self.status == 'PENDING_VALIDATION' and \
(datetime.datetime.now() - self.created_at).total_seconds() > 60: # 1min
self.status = 'ISSUED'
def describe(self):
# 'RenewalSummary': {}, # Only when cert is amazon issued
if self._key.key_size == 1024:
key_algo = 'RSA_1024'
elif self._key.key_size == 2048:
key_algo = 'RSA_2048'
else:
key_algo = 'EC_prime256v1'
# Look for SANs
san_obj = self._cert.extensions.get_extension_for_oid(cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME)
sans = []
if san_obj is not None:
sans = [item.value for item in san_obj.value]
result = {
'Certificate': {
'CertificateArn': self.arn,
'DomainName': self.common_name,
'InUseBy': [],
'Issuer': self._cert.issuer.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value,
'KeyAlgorithm': key_algo,
'NotAfter': datetime_to_epoch(self._cert.not_valid_after),
'NotBefore': datetime_to_epoch(self._cert.not_valid_before),
'Serial': self._cert.serial,
'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''),
'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED.
'Subject': 'CN={0}'.format(self.common_name),
'SubjectAlternativeNames': sans,
'Type': self.type # One of IMPORTED, AMAZON_ISSUED
}
}
if self.type == 'IMPORTED':
result['Certificate']['ImportedAt'] = datetime_to_epoch(self.created_at)
else:
result['Certificate']['CreatedAt'] = datetime_to_epoch(self.created_at)
result['Certificate']['IssuedAt'] = datetime_to_epoch(self.created_at)
return result
def __str__(self):
return self.arn
def __repr__(self):
return '<Certificate>'
class AWSCertificateManagerBackend(BaseBackend):
def __init__(self, region):
super(AWSCertificateManagerBackend, self).__init__()
self.region = region
self._certificates = {}
self._idempotency_tokens = {}
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@staticmethod
def _arn_not_found(arn):
msg = 'Certificate with arn {0} not found in account {1}'.format(arn, DEFAULT_ACCOUNT_ID)
return AWSResourceNotFoundException(msg)
def _get_arn_from_idempotency_token(self, token):
"""
If token doesnt exist, return None, later it will be
set with an expiry and arn.
If token expiry has passed, delete entry and return None
Else return ARN
:param token: String token
:return: None or ARN
"""
now = datetime.datetime.now()
if token in self._idempotency_tokens:
if self._idempotency_tokens[token]['expires'] < now:
# Token has expired, new request
del self._idempotency_tokens[token]
return None
else:
return self._idempotency_tokens[token]['arn']
return None
def _set_idempotency_token_arn(self, token, arn):
self._idempotency_tokens[token] = {'arn': arn, 'expires': datetime.datetime.now() + datetime.timedelta(hours=1)}
def import_cert(self, certificate, private_key, chain=None, arn=None):
if arn is not None:
if arn not in self._certificates:
raise self._arn_not_found(arn)
else:
# Will reuse provided ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region, arn=arn)
else:
# Will generate a random ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region)
self._certificates[bundle.arn] = bundle
return bundle.arn
def get_certificates_list(self):
"""
Get list of certificates
:return: List of certificates
:rtype: list of CertBundle
"""
for arn in self._certificates.keys():
yield self.get_certificate(arn)
def get_certificate(self, arn):
if arn not in self._certificates:
raise self._arn_not_found(arn)
cert_bundle = self._certificates[arn]
cert_bundle.check()
return cert_bundle
def delete_certificate(self, arn):
if arn not in self._certificates:
raise self._arn_not_found(arn)
del self._certificates[arn]
def request_certificate(self, domain_name, domain_validation_options, idempotency_token, subject_alt_names):
if idempotency_token is not None:
arn = self._get_arn_from_idempotency_token(idempotency_token)
if arn is not None:
return arn
cert = CertBundle.generate_cert(domain_name, subject_alt_names)
if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert
return cert.arn
def add_tags_to_certificate(self, arn, tags):
# get_cert does arn check
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag['Key']
value = tag.get('Value', None)
cert_bundle.tags[key] = value
def remove_tags_from_certificate(self, arn, tags):
# get_cert does arn check
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag['Key']
value = tag.get('Value', None)
try:
# If value isnt provided, just delete key
if value is None:
del cert_bundle.tags[key]
# If value is provided, only delete if it matches what already exists
elif cert_bundle.tags[key] == value:
del cert_bundle.tags[key]
except KeyError:
pass
acm_backends = {}
for region, ec2_backend in ec2_backends.items():
acm_backends[region] = AWSCertificateManagerBackend(region)

224
moto/acm/responses.py Normal file
View File

@ -0,0 +1,224 @@
from __future__ import unicode_literals
import json
import base64
from moto.core.responses import BaseResponse
from .models import acm_backends, AWSError, AWSValidationException
class AWSCertificateManagerResponse(BaseResponse):
@property
def acm_backend(self):
"""
ACM Backend
:return: ACM Backend object
:rtype: moto.acm.models.AWSCertificateManagerBackend
"""
return acm_backends[self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param, default=None):
return self.request_params.get(param, default)
def add_tags_to_certificate(self):
arn = self._get_param('CertificateArn')
tags = self._get_param('Tags')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
self.acm_backend.add_tags_to_certificate(arn, tags)
except AWSError as err:
return err.response()
return ''
def delete_certificate(self):
arn = self._get_param('CertificateArn')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
self.acm_backend.delete_certificate(arn)
except AWSError as err:
return err.response()
return ''
def describe_certificate(self):
arn = self._get_param('CertificateArn')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
return json.dumps(cert_bundle.describe())
def get_certificate(self):
arn = self._get_param('CertificateArn')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
result = {
'Certificate': cert_bundle.cert.decode(),
'CertificateChain': cert_bundle.chain.decode()
}
return json.dumps(result)
def import_certificate(self):
"""
Returns errors on:
Certificate, PrivateKey or Chain not being properly formatted
Arn not existing if its provided
PrivateKey size > 2048
Certificate expired or is not yet in effect
Does not return errors on:
Checking Certificate is legit, or a selfsigned chain is provided
:return: str(JSON) for response
"""
certificate = self._get_param('Certificate')
private_key = self._get_param('PrivateKey')
chain = self._get_param('CertificateChain') # Optional
current_arn = self._get_param('CertificateArn') # Optional
# Simple parameter decoding. Rather do it here as its a data transport decision not part of the
# actual data
try:
certificate = base64.standard_b64decode(certificate)
except:
return AWSValidationException('The certificate is not PEM-encoded or is not valid.').response()
try:
private_key = base64.standard_b64decode(private_key)
except:
return AWSValidationException('The private key is not PEM-encoded or is not valid.').response()
if chain is not None:
try:
chain = base64.standard_b64decode(chain)
except:
return AWSValidationException('The certificate chain is not PEM-encoded or is not valid.').response()
try:
arn = self.acm_backend.import_cert(certificate, private_key, chain=chain, arn=current_arn)
except AWSError as err:
return err.response()
return json.dumps({'CertificateArn': arn})
def list_certificates(self):
certs = []
for cert_bundle in self.acm_backend.get_certificates_list():
certs.append({
'CertificateArn': cert_bundle.arn,
'DomainName': cert_bundle.common_name
})
result = {'CertificateSummaryList': certs}
return json.dumps(result)
def list_tags_for_certificate(self):
arn = self._get_param('CertificateArn')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return {'__type': 'MissingParameter', 'message': msg}, dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
result = {'Tags': []}
# Tag "objects" can not contain the Value part
for key, value in cert_bundle.tags.items():
tag_dict = {'Key': key}
if value is not None:
tag_dict['Value'] = value
result['Tags'].append(tag_dict)
return json.dumps(result)
def remove_tags_from_certificate(self):
arn = self._get_param('CertificateArn')
tags = self._get_param('Tags')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
self.acm_backend.remove_tags_from_certificate(arn, tags)
except AWSError as err:
return err.response()
return ''
def request_certificate(self):
domain_name = self._get_param('DomainName')
domain_validation_options = self._get_param('DomainValidationOptions') # is ignored atm
idempotency_token = self._get_param('IdempotencyToken')
subject_alt_names = self._get_param('SubjectAlternativeNames')
if subject_alt_names is not None and len(subject_alt_names) > 10:
# There is initial AWS limit of 10
msg = 'An ACM limit has been exceeded. Need to request SAN limit to be raised'
return json.dumps({'__type': 'LimitExceededException', 'message': msg}), dict(status=400)
try:
arn = self.acm_backend.request_certificate(domain_name, domain_validation_options, idempotency_token, subject_alt_names)
except AWSError as err:
return err.response()
return json.dumps({'CertificateArn': arn})
def resend_validation_email(self):
arn = self._get_param('CertificateArn')
domain = self._get_param('Domain')
# ValidationDomain not used yet.
# Contains domain which is equal to or a subset of Domain
# that AWS will send validation emails to
# https://docs.aws.amazon.com/acm/latest/APIReference/API_ResendValidationEmail.html
# validation_domain = self._get_param('ValidationDomain')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
if cert_bundle.common_name != domain:
msg = 'Parameter Domain does not match certificate domain'
_type = 'InvalidDomainValidationOptionsException'
return json.dumps({'__type': _type, 'message': msg}), dict(status=400)
except AWSError as err:
return err.response()
return ''

10
moto/acm/urls.py Normal file
View File

@ -0,0 +1,10 @@
from __future__ import unicode_literals
from .responses import AWSCertificateManagerResponse
url_bases = [
"https?://acm.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': AWSCertificateManagerResponse.dispatch,
}

7
moto/acm/utils.py Normal file
View File

@ -0,0 +1,7 @@
import uuid
def make_arn_for_certificate(account_id, region_name):
# Example
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(region_name, account_id, uuid.uuid4())

View File

@ -0,0 +1,14 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
class AutoscalingClientError(RESTError):
code = 500
class ResourceContentionError(AutoscalingClientError):
def __init__(self):
super(ResourceContentionError, self).__init__(
"ResourceContentionError",
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).")

View File

@ -4,21 +4,26 @@ from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.ec2 import ec2_backends from moto.ec2 import ec2_backends
from moto.elb import elb_backends from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
from moto.elb.exceptions import LoadBalancerNotFoundError from moto.elb.exceptions import LoadBalancerNotFoundError
from .exceptions import (
ResourceContentionError,
)
# http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown # http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
DEFAULT_COOLDOWN = 300 DEFAULT_COOLDOWN = 300
ASG_NAME_TAG = "aws:autoscaling:groupName"
class InstanceState(object): class InstanceState(object):
def __init__(self, instance, lifecycle_state="InService", health_status="Healthy"):
def __init__(self, instance, lifecycle_state="InService"):
self.instance = instance self.instance = instance
self.lifecycle_state = lifecycle_state self.lifecycle_state = lifecycle_state
self.health_status = health_status
class FakeScalingPolicy(BaseModel): class FakeScalingPolicy(BaseModel):
def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment, def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment,
cooldown, autoscaling_backend): cooldown, autoscaling_backend):
self.name = name self.name = name
@ -45,7 +50,6 @@ class FakeScalingPolicy(BaseModel):
class FakeLaunchConfiguration(BaseModel): class FakeLaunchConfiguration(BaseModel):
def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data, def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data,
instance_type, instance_monitoring, instance_profile_name, instance_type, instance_monitoring, instance_profile_name,
spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict): spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict):
@ -144,11 +148,10 @@ class FakeLaunchConfiguration(BaseModel):
class FakeAutoScalingGroup(BaseModel): class FakeAutoScalingGroup(BaseModel):
def __init__(self, name, availability_zones, desired_capacity, max_size, def __init__(self, name, availability_zones, desired_capacity, max_size,
min_size, launch_config_name, vpc_zone_identifier, min_size, launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period, health_check_type, default_cooldown, health_check_period, health_check_type,
load_balancers, placement_group, termination_policies, load_balancers, target_group_arns, placement_group, termination_policies,
autoscaling_backend, tags): autoscaling_backend, tags):
self.autoscaling_backend = autoscaling_backend self.autoscaling_backend = autoscaling_backend
self.name = name self.name = name
@ -165,12 +168,13 @@ class FakeAutoScalingGroup(BaseModel):
self.health_check_period = health_check_period self.health_check_period = health_check_period
self.health_check_type = health_check_type if health_check_type else "EC2" self.health_check_type = health_check_type if health_check_type else "EC2"
self.load_balancers = load_balancers self.load_balancers = load_balancers
self.target_group_arns = target_group_arns
self.placement_group = placement_group self.placement_group = placement_group
self.termination_policies = termination_policies self.termination_policies = termination_policies
self.instance_states = [] self.instance_states = []
self.set_desired_capacity(desired_capacity)
self.tags = tags if tags else [] self.tags = tags if tags else []
self.set_desired_capacity(desired_capacity)
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -178,6 +182,7 @@ class FakeAutoScalingGroup(BaseModel):
launch_config_name = properties.get("LaunchConfigurationName") launch_config_name = properties.get("LaunchConfigurationName")
load_balancer_names = properties.get("LoadBalancerNames", []) load_balancer_names = properties.get("LoadBalancerNames", [])
target_group_arns = properties.get("TargetGroupARNs", [])
backend = autoscaling_backends[region_name] backend = autoscaling_backends[region_name]
group = backend.create_autoscaling_group( group = backend.create_autoscaling_group(
@ -193,6 +198,7 @@ class FakeAutoScalingGroup(BaseModel):
health_check_period=properties.get("HealthCheckGracePeriod"), health_check_period=properties.get("HealthCheckGracePeriod"),
health_check_type=properties.get("HealthCheckType"), health_check_type=properties.get("HealthCheckType"),
load_balancers=load_balancer_names, load_balancers=load_balancer_names,
target_group_arns=target_group_arns,
placement_group=None, placement_group=None,
termination_policies=properties.get("TerminationPolicies", []), termination_policies=properties.get("TerminationPolicies", []),
tags=properties.get("Tags", []), tags=properties.get("Tags", []),
@ -223,7 +229,7 @@ class FakeAutoScalingGroup(BaseModel):
def update(self, availability_zones, desired_capacity, max_size, min_size, def update(self, availability_zones, desired_capacity, max_size, min_size,
launch_config_name, vpc_zone_identifier, default_cooldown, launch_config_name, vpc_zone_identifier, default_cooldown,
health_check_period, health_check_type, load_balancers, health_check_period, health_check_type,
placement_group, termination_policies): placement_group, termination_policies):
if availability_zones: if availability_zones:
self.availability_zones = availability_zones self.availability_zones = availability_zones
@ -259,18 +265,10 @@ class FakeAutoScalingGroup(BaseModel):
if self.desired_capacity > curr_instance_count: if self.desired_capacity > curr_instance_count:
# Need more instances # Need more instances
count_needed = int(self.desired_capacity) - \ count_needed = int(self.desired_capacity) - int(curr_instance_count)
int(curr_instance_count)
reservation = self.autoscaling_backend.ec2_backend.add_instances( propagated_tags = self.get_propagated_tags()
self.launch_config.image_id, self.replace_autoscaling_group_instances(count_needed, propagated_tags)
count_needed,
self.launch_config.user_data,
self.launch_config.security_groups,
instance_type=self.launch_config.instance_type,
)
for instance in reservation.instances:
instance.autoscaling_group = self
self.instance_states.append(InstanceState(instance))
else: else:
# Need to remove some instances # Need to remove some instances
count_to_remove = curr_instance_count - self.desired_capacity count_to_remove = curr_instance_count - self.desired_capacity
@ -281,21 +279,51 @@ class FakeAutoScalingGroup(BaseModel):
instance_ids_to_remove) instance_ids_to_remove)
self.instance_states = self.instance_states[count_to_remove:] self.instance_states = self.instance_states[count_to_remove:]
def get_propagated_tags(self):
propagated_tags = {}
for tag in self.tags:
# boto uses 'propagate_at_launch
# boto3 and cloudformation use PropagateAtLaunch
if 'propagate_at_launch' in tag and tag['propagate_at_launch'] == 'true':
propagated_tags[tag['key']] = tag['value']
if 'PropagateAtLaunch' in tag and tag['PropagateAtLaunch']:
propagated_tags[tag['Key']] = tag['Value']
return propagated_tags
def replace_autoscaling_group_instances(self, count_needed, propagated_tags):
propagated_tags[ASG_NAME_TAG] = self.name
reservation = self.autoscaling_backend.ec2_backend.add_instances(
self.launch_config.image_id,
count_needed,
self.launch_config.user_data,
self.launch_config.security_groups,
instance_type=self.launch_config.instance_type,
tags={'instance': propagated_tags}
)
for instance in reservation.instances:
instance.autoscaling_group = self
self.instance_states.append(InstanceState(instance))
def append_target_groups(self, target_group_arns):
append = [x for x in target_group_arns if x not in self.target_group_arns]
self.target_group_arns.extend(append)
class AutoScalingBackend(BaseBackend): class AutoScalingBackend(BaseBackend):
def __init__(self, ec2_backend, elb_backend, elbv2_backend):
def __init__(self, ec2_backend, elb_backend):
self.autoscaling_groups = OrderedDict() self.autoscaling_groups = OrderedDict()
self.launch_configurations = OrderedDict() self.launch_configurations = OrderedDict()
self.policies = {} self.policies = {}
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.elb_backend = elb_backend self.elb_backend = elb_backend
self.elbv2_backend = elbv2_backend
def reset(self): def reset(self):
ec2_backend = self.ec2_backend ec2_backend = self.ec2_backend
elb_backend = self.elb_backend elb_backend = self.elb_backend
elbv2_backend = self.elbv2_backend
self.__dict__ = {} self.__dict__ = {}
self.__init__(ec2_backend, elb_backend) self.__init__(ec2_backend, elb_backend, elbv2_backend)
def create_launch_configuration(self, name, image_id, key_name, kernel_id, ramdisk_id, def create_launch_configuration(self, name, image_id, key_name, kernel_id, ramdisk_id,
security_groups, user_data, instance_type, security_groups, user_data, instance_type,
@ -335,7 +363,8 @@ class AutoScalingBackend(BaseBackend):
launch_config_name, vpc_zone_identifier, launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period, default_cooldown, health_check_period,
health_check_type, load_balancers, health_check_type, load_balancers,
placement_group, termination_policies, tags): target_group_arns, placement_group,
termination_policies, tags):
def make_int(value): def make_int(value):
return int(value) if value is not None else value return int(value) if value is not None else value
@ -361,6 +390,7 @@ class AutoScalingBackend(BaseBackend):
health_check_period=health_check_period, health_check_period=health_check_period,
health_check_type=health_check_type, health_check_type=health_check_type,
load_balancers=load_balancers, load_balancers=load_balancers,
target_group_arns=target_group_arns,
placement_group=placement_group, placement_group=placement_group,
termination_policies=termination_policies, termination_policies=termination_policies,
autoscaling_backend=self, autoscaling_backend=self,
@ -369,19 +399,20 @@ class AutoScalingBackend(BaseBackend):
self.autoscaling_groups[name] = group self.autoscaling_groups[name] = group
self.update_attached_elbs(group.name) self.update_attached_elbs(group.name)
self.update_attached_target_groups(group.name)
return group return group
def update_autoscaling_group(self, name, availability_zones, def update_autoscaling_group(self, name, availability_zones,
desired_capacity, max_size, min_size, desired_capacity, max_size, min_size,
launch_config_name, vpc_zone_identifier, launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period, default_cooldown, health_check_period,
health_check_type, load_balancers, health_check_type, placement_group,
placement_group, termination_policies): termination_policies):
group = self.autoscaling_groups[name] group = self.autoscaling_groups[name]
group.update(availability_zones, desired_capacity, max_size, group.update(availability_zones, desired_capacity, max_size,
min_size, launch_config_name, vpc_zone_identifier, min_size, launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period, health_check_type, default_cooldown, health_check_period, health_check_type,
load_balancers, placement_group, termination_policies) placement_group, termination_policies)
return group return group
def describe_autoscaling_groups(self, names): def describe_autoscaling_groups(self, names):
@ -401,6 +432,46 @@ class AutoScalingBackend(BaseBackend):
instance_states.extend(group.instance_states) instance_states.extend(group.instance_states)
return instance_states return instance_states
def attach_instances(self, group_name, instance_ids):
group = self.autoscaling_groups[group_name]
original_size = len(group.instance_states)
if (original_size + len(instance_ids)) > group.max_size:
raise ResourceContentionError
else:
group.desired_capacity = original_size + len(instance_ids)
new_instances = [InstanceState(self.ec2_backend.get_instance(x)) for x in instance_ids]
for instance in new_instances:
self.ec2_backend.create_tags([instance.instance.id], {ASG_NAME_TAG: group.name})
group.instance_states.extend(new_instances)
self.update_attached_elbs(group.name)
def set_instance_health(self, instance_id, health_status, should_respect_grace_period):
instance = self.ec2_backend.get_instance(instance_id)
instance_state = next(instance_state for group in self.autoscaling_groups.values()
for instance_state in group.instance_states if instance_state.instance.id == instance.id)
instance_state.health_status = health_status
def detach_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name]
original_size = len(group.instance_states)
detached_instances = [x for x in group.instance_states if x.instance.id in instance_ids]
for instance in detached_instances:
self.ec2_backend.delete_tags([instance.instance.id], {ASG_NAME_TAG: group.name})
new_instance_state = [x for x in group.instance_states if x.instance.id not in instance_ids]
group.instance_states = new_instance_state
if should_decrement:
group.desired_capacity = original_size - len(instance_ids)
else:
count_needed = len(instance_ids)
group.replace_autoscaling_group_instances(count_needed, group.get_propagated_tags())
self.update_attached_elbs(group_name)
return detached_instances
def set_desired_capacity(self, group_name, desired_capacity): def set_desired_capacity(self, group_name, desired_capacity):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group.set_desired_capacity(desired_capacity) group.set_desired_capacity(desired_capacity)
@ -453,6 +524,10 @@ class AutoScalingBackend(BaseBackend):
group_instance_ids = set( group_instance_ids = set(
state.instance.id for state in group.instance_states) state.instance.id for state in group.instance_states)
# skip this if group.load_balancers is empty
# otherwise elb_backend.describe_load_balancers returns all available load balancers
if not group.load_balancers:
return
try: try:
elbs = self.elb_backend.describe_load_balancers( elbs = self.elb_backend.describe_load_balancers(
names=group.load_balancers) names=group.load_balancers)
@ -467,8 +542,25 @@ class AutoScalingBackend(BaseBackend):
self.elb_backend.deregister_instances( self.elb_backend.deregister_instances(
elb.name, elb_instace_ids - group_instance_ids) elb.name, elb_instace_ids - group_instance_ids)
def create_or_update_tags(self, tags): def update_attached_target_groups(self, group_name):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(
state.instance.id for state in group.instance_states)
# no action necessary if target_group_arns is empty
if not group.target_group_arns:
return
target_groups = self.elbv2_backend.describe_target_groups(
target_group_arns=group.target_group_arns,
load_balancer_arn=None,
names=None)
for target_group in target_groups:
asg_targets = [{'id': x, 'port': target_group.port} for x in group_instance_ids]
self.elbv2_backend.register_targets(target_group.arn, (asg_targets))
def create_or_update_tags(self, tags):
for tag in tags: for tag in tags:
group_name = tag["resource_id"] group_name = tag["resource_id"]
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
@ -488,8 +580,42 @@ class AutoScalingBackend(BaseBackend):
group.tags = new_tags group.tags = new_tags
def attach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name]
group.load_balancers.extend(
[x for x in load_balancer_names if x not in group.load_balancers])
self.update_attached_elbs(group_name)
def describe_load_balancers(self, group_name):
return self.autoscaling_groups[group_name].load_balancers
def detach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(
state.instance.id for state in group.instance_states)
elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
for elb in elbs:
self.elb_backend.deregister_instances(
elb.name, group_instance_ids)
group.load_balancers = [x for x in group.load_balancers if x not in load_balancer_names]
def attach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name]
group.append_target_groups(target_group_arns)
self.update_attached_target_groups(group_name)
def describe_load_balancer_target_groups(self, group_name):
return self.autoscaling_groups[group_name].target_group_arns
def detach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name]
group.target_group_arns = [x for x in group.target_group_arns if x not in target_group_arns]
for target_group in target_group_arns:
asg_targets = [{'id': x.instance.id} for x in group.instance_states]
self.elbv2_backend.deregister_targets(target_group, (asg_targets))
autoscaling_backends = {} autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items(): for region, ec2_backend in ec2_backends.items():
autoscaling_backends[region] = AutoScalingBackend( autoscaling_backends[region] = AutoScalingBackend(
ec2_backend, elb_backends[region]) ec2_backend, elb_backends[region], elbv2_backends[region])

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import amz_crc32, amzn_request_id
from .models import autoscaling_backends from .models import autoscaling_backends
@ -79,6 +80,7 @@ class AutoScalingResponse(BaseResponse):
health_check_period=self._get_int_param('HealthCheckGracePeriod'), health_check_period=self._get_int_param('HealthCheckGracePeriod'),
health_check_type=self._get_param('HealthCheckType'), health_check_type=self._get_param('HealthCheckType'),
load_balancers=self._get_multi_param('LoadBalancerNames.member'), load_balancers=self._get_multi_param('LoadBalancerNames.member'),
target_group_arns=self._get_multi_param('TargetGroupARNs.member'),
placement_group=self._get_param('PlacementGroup'), placement_group=self._get_param('PlacementGroup'),
termination_policies=self._get_multi_param( termination_policies=self._get_multi_param(
'TerminationPolicies.member'), 'TerminationPolicies.member'),
@ -87,6 +89,74 @@ class AutoScalingResponse(BaseResponse):
template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render() return template.render()
@amz_crc32
@amzn_request_id
def attach_instances(self):
group_name = self._get_param('AutoScalingGroupName')
instance_ids = self._get_multi_param('InstanceIds.member')
self.autoscaling_backend.attach_instances(
group_name, instance_ids)
template = self.response_template(ATTACH_INSTANCES_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def set_instance_health(self):
instance_id = self._get_param('InstanceId')
health_status = self._get_param("HealthStatus")
if health_status not in ['Healthy', 'Unhealthy']:
raise ValueError('Valid instance health states are: [Healthy, Unhealthy]')
should_respect_grace_period = self._get_param("ShouldRespectGracePeriod")
self.autoscaling_backend.set_instance_health(instance_id, health_status, should_respect_grace_period)
template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def detach_instances(self):
group_name = self._get_param('AutoScalingGroupName')
instance_ids = self._get_multi_param('InstanceIds.member')
should_decrement_string = self._get_param('ShouldDecrementDesiredCapacity')
if should_decrement_string == 'true':
should_decrement = True
else:
should_decrement = False
detached_instances = self.autoscaling_backend.detach_instances(
group_name, instance_ids, should_decrement)
template = self.response_template(DETACH_INSTANCES_TEMPLATE)
return template.render(detached_instances=detached_instances)
@amz_crc32
@amzn_request_id
def attach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
target_group_arns = self._get_multi_param('TargetGroupARNs.member')
self.autoscaling_backend.attach_load_balancer_target_groups(
group_name, target_group_arns)
template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def describe_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups(
group_name)
template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS)
return template.render(target_group_arns=target_group_arns)
@amz_crc32
@amzn_request_id
def detach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
target_group_arns = self._get_multi_param('TargetGroupARNs.member')
self.autoscaling_backend.detach_load_balancer_target_groups(
group_name, target_group_arns)
template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render()
def describe_auto_scaling_groups(self): def describe_auto_scaling_groups(self):
names = self._get_multi_param("AutoScalingGroupNames.member") names = self._get_multi_param("AutoScalingGroupNames.member")
token = self._get_param("NextToken") token = self._get_param("NextToken")
@ -119,7 +189,6 @@ class AutoScalingResponse(BaseResponse):
default_cooldown=self._get_int_param('DefaultCooldown'), default_cooldown=self._get_int_param('DefaultCooldown'),
health_check_period=self._get_int_param('HealthCheckGracePeriod'), health_check_period=self._get_int_param('HealthCheckGracePeriod'),
health_check_type=self._get_param('HealthCheckType'), health_check_type=self._get_param('HealthCheckType'),
load_balancers=self._get_multi_param('LoadBalancerNames.member'),
placement_group=self._get_param('PlacementGroup'), placement_group=self._get_param('PlacementGroup'),
termination_policies=self._get_multi_param( termination_policies=self._get_multi_param(
'TerminationPolicies.member'), 'TerminationPolicies.member'),
@ -186,6 +255,34 @@ class AutoScalingResponse(BaseResponse):
template = self.response_template(EXECUTE_POLICY_TEMPLATE) template = self.response_template(EXECUTE_POLICY_TEMPLATE)
return template.render() return template.render()
@amz_crc32
@amzn_request_id
def attach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.attach_load_balancers(
group_name, load_balancer_names)
template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def describe_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
load_balancers = self.autoscaling_backend.describe_load_balancers(group_name)
template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE)
return template.render(load_balancers=load_balancers)
@amz_crc32
@amzn_request_id
def detach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.detach_load_balancers(
group_name, load_balancer_names)
template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE)
return template.render()
CREATE_LAUNCH_CONFIGURATION_TEMPLATE = """<CreateLaunchConfigurationResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/"> CREATE_LAUNCH_CONFIGURATION_TEMPLATE = """<CreateLaunchConfigurationResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<ResponseMetadata> <ResponseMetadata>
@ -284,6 +381,72 @@ CREATE_AUTOSCALING_GROUP_TEMPLATE = """<CreateAutoScalingGroupResponse xmlns="ht
</ResponseMetadata> </ResponseMetadata>
</CreateAutoScalingGroupResponse>""" </CreateAutoScalingGroupResponse>"""
ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE = """<AttachLoadBalancerTargetGroupsResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<AttachLoadBalancerTargetGroupsResult>
</AttachLoadBalancerTargetGroupsResult>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</AttachLoadBalancerTargetGroupsResponse>"""
ATTACH_INSTANCES_TEMPLATE = """<AttachInstancesResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<AttachInstancesResult>
</AttachInstancesResult>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</AttachInstancesResponse>"""
DESCRIBE_LOAD_BALANCER_TARGET_GROUPS = """<DescribeLoadBalancerTargetGroupsResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<DescribeLoadBalancerTargetGroupsResult>
<LoadBalancerTargetGroups>
{% for arn in target_group_arns %}
<member>
<LoadBalancerTargetGroupARN>{{ arn }}</LoadBalancerTargetGroupARN>
<State>Added</State>
</member>
{% endfor %}
</LoadBalancerTargetGroups>
</DescribeLoadBalancerTargetGroupsResult>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</DescribeLoadBalancerTargetGroupsResponse>"""
DETACH_INSTANCES_TEMPLATE = """<DetachInstancesResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<DetachInstancesResult>
<Activities>
{% for instance in detached_instances %}
<member>
<ActivityId>5091cb52-547a-47ce-a236-c9ccbc2cb2c9EXAMPLE</ActivityId>
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
<Cause>
At 2017-10-15T15:55:21Z instance {{ instance.instance.id }} was detached in response to a user request.
</Cause>
<Description>Detaching EC2 instance: {{ instance.instance.id }}</Description>
<StartTime>2017-10-15T15:55:21Z</StartTime>
<EndTime>2017-10-15T15:55:21Z</EndTime>
<StatusCode>InProgress</StatusCode>
<StatusMessage>InProgress</StatusMessage>
<Progress>50</Progress>
<Details>details</Details>
</member>
{% endfor %}
</Activities>
</DetachInstancesResult>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</DetachInstancesResponse>"""
DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE = """<DetachLoadBalancerTargetGroupsResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<DetachLoadBalancerTargetGroupsResult>
</DetachLoadBalancerTargetGroupsResult>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</DetachLoadBalancerTargetGroupsResponse>"""
DESCRIBE_AUTOSCALING_GROUPS_TEMPLATE = """<DescribeAutoScalingGroupsResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/"> DESCRIBE_AUTOSCALING_GROUPS_TEMPLATE = """<DescribeAutoScalingGroupsResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<DescribeAutoScalingGroupsResult> <DescribeAutoScalingGroupsResult>
<AutoScalingGroups> <AutoScalingGroups>
@ -309,7 +472,7 @@ DESCRIBE_AUTOSCALING_GROUPS_TEMPLATE = """<DescribeAutoScalingGroupsResponse xml
<Instances> <Instances>
{% for instance_state in group.instance_states %} {% for instance_state in group.instance_states %}
<member> <member>
<HealthStatus>HEALTHY</HealthStatus> <HealthStatus>{{ instance_state.health_status }}</HealthStatus>
<AvailabilityZone>us-east-1e</AvailabilityZone> <AvailabilityZone>us-east-1e</AvailabilityZone>
<InstanceId>{{ instance_state.instance.id }}</InstanceId> <InstanceId>{{ instance_state.instance.id }}</InstanceId>
<LaunchConfigurationName>{{ group.launch_config_name }}</LaunchConfigurationName> <LaunchConfigurationName>{{ group.launch_config_name }}</LaunchConfigurationName>
@ -384,7 +547,7 @@ DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE = """<DescribeAutoScalingInstancesRespon
<AutoScalingInstances> <AutoScalingInstances>
{% for instance_state in instance_states %} {% for instance_state in instance_states %}
<member> <member>
<HealthStatus>HEALTHY</HealthStatus> <HealthStatus>{{ instance_state.health_status }}</HealthStatus>
<AutoScalingGroupName>{{ instance_state.instance.autoscaling_group.name }}</AutoScalingGroupName> <AutoScalingGroupName>{{ instance_state.instance.autoscaling_group.name }}</AutoScalingGroupName>
<AvailabilityZone>us-east-1e</AvailabilityZone> <AvailabilityZone>us-east-1e</AvailabilityZone>
<InstanceId>{{ instance_state.instance.id }}</InstanceId> <InstanceId>{{ instance_state.instance.id }}</InstanceId>
@ -450,3 +613,40 @@ DELETE_POLICY_TEMPLATE = """<DeleteScalingPolicyResponse xmlns="http://autoscali
<RequestId>70a76d42-9665-11e2-9fdf-211deEXAMPLE</RequestId> <RequestId>70a76d42-9665-11e2-9fdf-211deEXAMPLE</RequestId>
</ResponseMetadata> </ResponseMetadata>
</DeleteScalingPolicyResponse>""" </DeleteScalingPolicyResponse>"""
ATTACH_LOAD_BALANCERS_TEMPLATE = """<AttachLoadBalancersResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<AttachLoadBalancersResult></AttachLoadBalancersResult>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</AttachLoadBalancersResponse>"""
DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<DescribeLoadBalancersResult>
<LoadBalancers>
{% for load_balancer in load_balancers %}
<member>
<LoadBalancerName>{{ load_balancer }}</LoadBalancerName>
<State>Added</State>
</member>
{% endfor %}
</LoadBalancers>
</DescribeLoadBalancersResult>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</DescribeLoadBalancersResponse>"""
DETACH_LOAD_BALANCERS_TEMPLATE = """<DetachLoadBalancersResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<DetachLoadBalancersResult></DetachLoadBalancersResult>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</DetachLoadBalancersResponse>"""
SET_INSTANCE_HEALTH_TEMPLATE = """<SetInstanceHealthResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<SetInstanceHealthResponse></SetInstanceHealthResponse>
<ResponseMetadata>
<RequestId>{{ requestid }}</RequestId>
</ResponseMetadata>
</SetInstanceHealthResponse>"""

View File

@ -1,33 +1,151 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import base64 import base64
from collections import defaultdict
import datetime import datetime
import docker.errors
import hashlib import hashlib
import io import io
import logging
import os
import json import json
import sys import re
import zipfile import zipfile
import uuid
try: import functools
from StringIO import StringIO import tarfile
except: import calendar
from io import StringIO import threading
import traceback
import requests.adapters
import boto.awslambda import boto.awslambda
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time_millis
from moto.s3.models import s3_backend from moto.s3.models import s3_backend
from moto.s3.exceptions import MissingBucket from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings
logger = logging.getLogger(__name__)
try:
from tempfile import TemporaryDirectory
except ImportError:
from backports.tempfile import TemporaryDirectory
_stderr_regex = re.compile(r'START|END|REPORT RequestId: .*')
_orig_adapter_send = requests.adapters.HTTPAdapter.send
def zip2tar(zip_bytes):
with TemporaryDirectory() as td:
tarname = os.path.join(td, 'data.tar')
timeshift = int((datetime.datetime.now() -
datetime.datetime.utcnow()).total_seconds())
with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as zipf, \
tarfile.TarFile(tarname, 'w') as tarf:
for zipinfo in zipf.infolist():
if zipinfo.filename[-1] == '/': # is_dir() is py3.6+
continue
tarinfo = tarfile.TarInfo(name=zipinfo.filename)
tarinfo.size = zipinfo.file_size
tarinfo.mtime = calendar.timegm(zipinfo.date_time) - timeshift
infile = zipf.open(zipinfo.filename)
tarf.addfile(tarinfo, infile)
with open(tarname, 'rb') as f:
tar_data = f.read()
return tar_data
class _VolumeRefCount:
__slots__ = "refcount", "volume"
def __init__(self, refcount, volume):
self.refcount = refcount
self.volume = volume
class _DockerDataVolumeContext:
_data_vol_map = defaultdict(lambda: _VolumeRefCount(0, None)) # {sha256: _VolumeRefCount}
_lock = threading.Lock()
def __init__(self, lambda_func):
self._lambda_func = lambda_func
self._vol_ref = None
@property
def name(self):
return self._vol_ref.volume.name
def __enter__(self):
# See if volume is already known
with self.__class__._lock:
self._vol_ref = self.__class__._data_vol_map[self._lambda_func.code_sha_256]
self._vol_ref.refcount += 1
if self._vol_ref.refcount > 1:
return self
# See if the volume already exists
for vol in self._lambda_func.docker_client.volumes.list():
if vol.name == self._lambda_func.code_sha_256:
self._vol_ref.volume = vol
return self
# It doesn't exist so we need to create it
self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(self._lambda_func.code_sha_256)
container = self._lambda_func.docker_client.containers.run('alpine', 'sleep 100', volumes={self.name: '/tmp/data'}, detach=True)
try:
tar_bytes = zip2tar(self._lambda_func.code_bytes)
container.put_archive('/tmp/data', tar_bytes)
finally:
container.remove(force=True)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
with self.__class__._lock:
self._vol_ref.refcount -= 1
if self._vol_ref.refcount == 0:
try:
self._vol_ref.volume.remove()
except docker.errors.APIError as e:
if e.status_code != 409:
raise
raise # multiple processes trying to use same volume?
class LambdaFunction(BaseModel): class LambdaFunction(BaseModel):
def __init__(self, spec, region, validate_s3=True):
def __init__(self, spec):
# required # required
self.region = region
self.code = spec['Code'] self.code = spec['Code']
self.function_name = spec['FunctionName'] self.function_name = spec['FunctionName']
self.handler = spec['Handler'] self.handler = spec['Handler']
self.role = spec['Role'] self.role = spec['Role']
self.run_time = spec['Runtime'] self.run_time = spec['Runtime']
self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get('Environment', {}).get('Variables', {})
self.docker_client = docker.from_env()
self.policy = ""
# Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked
if requests.adapters.HTTPAdapter.send != _orig_adapter_send:
_orig_get_adapter = self.docker_client.api.get_adapter
def replace_adapter_send(*args, **kwargs):
adapter = _orig_get_adapter(*args, **kwargs)
if isinstance(adapter, requests.adapters.HTTPAdapter):
adapter.send = functools.partial(_orig_adapter_send, adapter)
return adapter
self.docker_client.api.get_adapter = replace_adapter_send
# optional # optional
self.description = spec.get('Description', '') self.description = spec.get('Description', '')
@ -35,13 +153,18 @@ class LambdaFunction(BaseModel):
self.publish = spec.get('Publish', False) # this is ignored currently self.publish = spec.get('Publish', False) # this is ignored currently
self.timeout = spec.get('Timeout', 3) self.timeout = spec.get('Timeout', 3)
self.logs_group_name = '/aws/lambda/{}'.format(self.function_name)
self.logs_backend.ensure_log_group(self.logs_group_name, [])
# this isn't finished yet. it needs to find out the VpcId value # this isn't finished yet. it needs to find out the VpcId value
self._vpc_config = spec.get( self._vpc_config = spec.get(
'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []}) 'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []})
# auto-generated # auto-generated
self.version = '$LATEST' self.version = '$LATEST'
self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') self.last_modified = datetime.datetime.utcnow().strftime(
'%Y-%m-%d %H:%M:%S')
if 'ZipFile' in self.code: if 'ZipFile' in self.code:
# more hackery to handle unicode/bytes/str in python3 and python2 - # more hackery to handle unicode/bytes/str in python3 and python2 -
# argh! # argh!
@ -51,33 +174,39 @@ class LambdaFunction(BaseModel):
except Exception: except Exception:
to_unzip_code = base64.b64decode(self.code['ZipFile']) to_unzip_code = base64.b64decode(self.code['ZipFile'])
zbuffer = io.BytesIO() self.code_bytes = to_unzip_code
zbuffer.write(to_unzip_code)
zip_file = zipfile.ZipFile(zbuffer, 'r', zipfile.ZIP_DEFLATED)
self.code = zip_file.read("".join(zip_file.namelist()))
self.code_size = len(to_unzip_code) self.code_size = len(to_unzip_code)
self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest() self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest()
# TODO: we should be putting this in a lambda bucket
self.code['UUID'] = str(uuid.uuid4())
self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID'])
else: else:
# validate s3 bucket # validate s3 bucket and key
key = None
try: try:
# FIXME: does not validate bucket region # FIXME: does not validate bucket region
key = s3_backend.get_key( key = s3_backend.get_key(
self.code['S3Bucket'], self.code['S3Key']) self.code['S3Bucket'], self.code['S3Key'])
except MissingBucket: except MissingBucket:
raise ValueError( if do_validate_s3():
"InvalidParameterValueException", raise ValueError(
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist") "InvalidParameterValueException",
else: "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist")
# validate s3 key except MissingKey:
if key is None: if do_validate_s3():
raise ValueError( raise ValueError(
"InvalidParameterValueException", "InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.") "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.")
else: if key:
self.code_size = key.size self.code_bytes = key.value
self.code_sha_256 = hashlib.sha256(key.value).hexdigest() self.code_size = key.size
self.function_arn = 'arn:aws:lambda:123456789012:function:{0}'.format( self.code_sha_256 = hashlib.sha256(key.value).hexdigest()
self.function_name)
self.function_arn = 'arn:aws:lambda:{}:123456789012:function:{}'.format(
self.region, self.function_name)
self.tags = dict()
@property @property
def vpc_config(self): def vpc_config(self):
@ -90,7 +219,7 @@ class LambdaFunction(BaseModel):
return json.dumps(self.get_configuration()) return json.dumps(self.get_configuration())
def get_configuration(self): def get_configuration(self):
return { config = {
"CodeSha256": self.code_sha_256, "CodeSha256": self.code_sha_256,
"CodeSize": self.code_size, "CodeSize": self.code_size,
"Description": self.description, "Description": self.description,
@ -106,65 +235,105 @@ class LambdaFunction(BaseModel):
"VpcConfig": self.vpc_config, "VpcConfig": self.vpc_config,
} }
if self.environment_vars:
config['Environment'] = {
'Variables': self.environment_vars
}
return config
def get_code(self): def get_code(self):
return { return {
"Code": { "Code": {
"Location": "s3://lambda-functions.aws.amazon.com/{0}".format(self.code['S3Key']), "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(self.region, self.code['S3Key']),
"RepositoryType": "S3" "RepositoryType": "S3"
}, },
"Configuration": self.get_configuration(), "Configuration": self.get_configuration(),
} }
def convert(self, s): @staticmethod
def convert(s):
try: try:
return str(s, encoding='utf-8') return str(s, encoding='utf-8')
except: except:
return s return s
def is_json(self, test_str): @staticmethod
def is_json(test_str):
try: try:
response = json.loads(test_str) response = json.loads(test_str)
except: except:
response = test_str response = test_str
return response return response
def _invoke_lambda(self, code, event={}, context={}): def _invoke_lambda(self, code, event=None, context=None):
# TO DO: context not yet implemented # TODO: context not yet implemented
try: if event is None:
mycode = "\n".join(['import json', event = dict()
self.convert(self.code), if context is None:
self.convert('print(json.dumps(lambda_handler(%s, %s)))' % (self.is_json(self.convert(event)), context))]) context = {}
except Exception as ex:
print("Exception %s", ex)
errored = False
try: try:
original_stdout = sys.stdout # TODO: I believe we can keep the container running and feed events as needed
original_stderr = sys.stderr # also need to hook it up to the other services so it can make kws/s3 etc calls
codeOut = StringIO() # Should get invoke_id /RequestId from invovation
codeErr = StringIO() env_vars = {
sys.stdout = codeOut "AWS_LAMBDA_FUNCTION_TIMEOUT": self.timeout,
sys.stderr = codeErr "AWS_LAMBDA_FUNCTION_NAME": self.function_name,
exec(mycode) "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": self.memory_size,
exec_err = codeErr.getvalue() "AWS_LAMBDA_FUNCTION_VERSION": self.version,
exec_out = codeOut.getvalue() "AWS_REGION": self.region,
result = self.convert(exec_out.strip()) }
if exec_err:
result = "\n".join([exec_out.strip(), self.convert(exec_err)]) env_vars.update(self.environment_vars)
except Exception as ex:
errored = True container = output = exit_code = None
result = '%s\n\n\nException %s' % (mycode, ex) with _DockerDataVolumeContext(self) as data_vol:
finally: try:
codeErr.close() run_kwargs = dict(links={'motoserver': 'motoserver'}) if settings.TEST_SERVER_MODE else {}
codeOut.close() container = self.docker_client.containers.run(
sys.stdout = original_stdout "lambci/lambda:{}".format(self.run_time),
sys.stderr = original_stderr [self.handler, json.dumps(event)], remove=False,
return self.convert(result), errored mem_limit="{}m".format(self.memory_size),
volumes=["{}:/var/task".format(data_vol.name)], environment=env_vars, detach=True, **run_kwargs)
finally:
if container:
exit_code = container.wait()
output = container.logs(stdout=False, stderr=True)
output += container.logs(stdout=True, stderr=False)
container.remove()
output = output.decode('utf-8')
# Send output to "logs" backend
invoke_id = uuid.uuid4().hex
log_stream_name = "{date.year}/{date.month:02d}/{date.day:02d}/[{version}]{invoke_id}".format(
date=datetime.datetime.utcnow(), version=self.version, invoke_id=invoke_id
)
self.logs_backend.create_log_stream(self.logs_group_name, log_stream_name)
log_events = [{'timestamp': unix_time_millis(), "message": line}
for line in output.splitlines()]
self.logs_backend.put_log_events(self.logs_group_name, log_stream_name, log_events, None)
if exit_code != 0:
raise Exception(
'lambda invoke failed output: {}'.format(output))
# strip out RequestId lines
output = os.linesep.join([line for line in self.convert(output).splitlines() if not _stderr_regex.match(line)])
return output, False
except BaseException as e:
traceback.print_exc()
return "error running lambda: {}".format(e), True
def invoke(self, body, request_headers, response_headers): def invoke(self, body, request_headers, response_headers):
payload = dict() payload = dict()
if body:
body = json.loads(body)
# Get the invocation type: # Get the invocation type:
res, errored = self._invoke_lambda(code=self.code, event=body) res, errored = self._invoke_lambda(code=self.code, event=body)
if request_headers.get("x-amz-invocation-type") == "RequestResponse": if request_headers.get("x-amz-invocation-type") == "RequestResponse":
@ -180,7 +349,8 @@ class LambdaFunction(BaseModel):
return result return result
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties'] properties = cloudformation_json['Properties']
# required # required
@ -203,12 +373,21 @@ class LambdaFunction(BaseModel):
# this snippet converts this plaintext code to a proper base64-encoded ZIP file. # this snippet converts this plaintext code to a proper base64-encoded ZIP file.
if 'ZipFile' in properties['Code']: if 'ZipFile' in properties['Code']:
spec['Code']['ZipFile'] = base64.b64encode( spec['Code']['ZipFile'] = base64.b64encode(
cls._create_zipfile_from_plaintext_code(spec['Code']['ZipFile'])) cls._create_zipfile_from_plaintext_code(
spec['Code']['ZipFile']))
backend = lambda_backends[region_name] backend = lambda_backends[region_name]
fn = backend.create_function(spec) fn = backend.create_function(spec)
return fn return fn
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import \
UnformattedGetAttTemplateException
if attribute_name == 'Arn':
return 'arn:aws:lambda:{0}:123456789012:function:{1}'.format(
self.region, self.function_name)
raise UnformattedGetAttTemplateException()
@staticmethod @staticmethod
def _create_zipfile_from_plaintext_code(code): def _create_zipfile_from_plaintext_code(code):
zip_output = io.BytesIO() zip_output = io.BytesIO()
@ -219,33 +398,146 @@ class LambdaFunction(BaseModel):
return zip_output.read() return zip_output.read()
class LambdaBackend(BaseBackend): class EventSourceMapping(BaseModel):
def __init__(self, spec):
# required
self.function_name = spec['FunctionName']
self.event_source_arn = spec['EventSourceArn']
self.starting_position = spec['StartingPosition']
def __init__(self): # optional
self.batch_size = spec.get('BatchSize', 100)
self.enabled = spec.get('Enabled', True)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp',
None)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
spec = {
'FunctionName': properties['FunctionName'],
'EventSourceArn': properties['EventSourceArn'],
'StartingPosition': properties['StartingPosition']
}
optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]
return EventSourceMapping(spec)
class LambdaVersion(BaseModel):
def __init__(self, spec):
self.version = spec['Version']
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
spec = {
'Version': properties.get('Version')
}
return LambdaVersion(spec)
class LambdaBackend(BaseBackend):
def __init__(self, region_name):
self._functions = {} self._functions = {}
self.region_name = region_name
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def has_function(self, function_name): def has_function(self, function_name):
return function_name in self._functions return function_name in self._functions
def has_function_arn(self, function_arn):
return self.get_function_by_arn(function_arn) is not None
def create_function(self, spec): def create_function(self, spec):
fn = LambdaFunction(spec) fn = LambdaFunction(spec, self.region_name)
self._functions[fn.function_name] = fn self._functions[fn.function_name] = fn
return fn return fn
def get_function(self, function_name): def get_function(self, function_name):
return self._functions[function_name] return self._functions[function_name]
def get_function_by_arn(self, function_arn):
for function in self._functions.values():
if function.function_arn == function_arn:
return function
return None
def delete_function(self, function_name): def delete_function(self, function_name):
del self._functions[function_name] del self._functions[function_name]
def list_functions(self): def list_functions(self):
return self._functions.values() return self._functions.values()
def send_message(self, function_name, message):
event = {
"Records": [
{
"EventVersion": "1.0",
"EventSubscriptionArn": "arn:aws:sns:EXAMPLE",
"EventSource": "aws:sns",
"Sns": {
"SignatureVersion": "1",
"Timestamp": "1970-01-01T00:00:00.000Z",
"Signature": "EXAMPLE",
"SigningCertUrl": "EXAMPLE",
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": message,
"MessageAttributes": {
"Test": {
"Type": "String",
"Value": "TestString"
},
"TestBinary": {
"Type": "Binary",
"Value": "TestBinary"
}
},
"Type": "Notification",
"UnsubscribeUrl": "EXAMPLE",
"TopicArn": "arn:aws:sns:EXAMPLE",
"Subject": "TestInvoke"
}
}
]
}
self._functions[function_name].invoke(json.dumps(event), {}, {})
pass
def list_tags(self, resource):
return self.get_function_by_arn(resource).tags
def tag_resource(self, resource, tags):
self.get_function_by_arn(resource).tags.update(tags)
def untag_resource(self, resource, tagKeys):
function = self.get_function_by_arn(resource)
for key in tagKeys:
try:
del function.tags[key]
except KeyError:
pass
# Don't care
def add_policy(self, function_name, policy):
self.get_function(function_name).policy = policy
def do_validate_s3():
return os.environ.get('VALIDATE_LAMBDA_S3', '') in ['', '1', 'true']
lambda_backends = {}
for region in boto.awslambda.regions():
lambda_backends[region.name] = LambdaBackend()
# Handle us forgotten regions, unless Lambda truly only runs out of US and # Handle us forgotten regions, unless Lambda truly only runs out of US and
for region in ['ap-southeast-2']: lambda_backends = {_region.name: LambdaBackend(_region.name)
lambda_backends[region] = LambdaBackend() for _region in boto.awslambda.regions()}
lambda_backends['ap-southeast-2'] = LambdaBackend('ap-southeast-2')

View File

@ -3,6 +3,13 @@ from __future__ import unicode_literals
import json import json
import re import re
try:
from urllib import unquote
from urlparse import urlparse, parse_qs
except:
from urllib.parse import unquote, urlparse, parse_qs
from moto.core.utils import amz_crc32, amzn_request_id
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
@ -26,6 +33,8 @@ class LambdaResponse(BaseResponse):
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
@amz_crc32
@amzn_request_id
def invoke(self, request, full_url, headers): def invoke(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'POST': if request.method == 'POST':
@ -33,6 +42,55 @@ class LambdaResponse(BaseResponse):
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
@amz_crc32
@amzn_request_id
def invoke_async(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'POST':
return self._invoke_async(request, full_url)
else:
raise ValueError("Cannot handle request")
def tag(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
return self._list_tags(request, full_url)
elif request.method == 'POST':
return self._tag_resource(request, full_url)
elif request.method == 'DELETE':
return self._untag_resource(request, full_url)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def policy(self, request, full_url, headers):
if request.method == 'GET':
return self._get_policy(request, full_url, headers)
if request.method == 'POST':
return self._add_policy(request, full_url, headers)
def _add_policy(self, request, full_url, headers):
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_name = path.split('/')[-2]
if lambda_backend.has_function(function_name):
policy = request.body.decode('utf8')
lambda_backend.add_policy(function_name, policy)
return 200, {}, json.dumps(dict(Statement=policy))
else:
return 404, {}, "{}"
def _get_policy(self, request, full_url, headers):
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_name = path.split('/')[-2]
if lambda_backend.has_function(function_name):
function = lambda_backend.get_function(function_name)
return 200, {}, json.dumps(dict(Policy="{\"Statement\":[" + function.policy + "]}"))
else:
return 404, {}, "{}"
def _invoke(self, request, full_url): def _invoke(self, request, full_url):
response_headers = {} response_headers = {}
lambda_backend = self.get_lambda_backend(full_url) lambda_backend = self.get_lambda_backend(full_url)
@ -48,6 +106,20 @@ class LambdaResponse(BaseResponse):
else: else:
return 404, response_headers, "{}" return 404, response_headers, "{}"
def _invoke_async(self, request, full_url):
response_headers = {}
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_name = path.split('/')[-3]
if lambda_backend.has_function(function_name):
fn = lambda_backend.get_function(function_name)
fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(0)
return 202, response_headers, ""
else:
return 404, response_headers, "{}"
def _list_functions(self, request, full_url, headers): def _list_functions(self, request, full_url, headers):
lambda_backend = self.get_lambda_backend(full_url) lambda_backend = self.get_lambda_backend(full_url)
return 200, {}, json.dumps({ return 200, {}, json.dumps({
@ -102,3 +174,43 @@ class LambdaResponse(BaseResponse):
return region.group(1) return region.group(1)
else: else:
return self.default_region return self.default_region
def _list_tags(self, request, full_url):
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_arn = unquote(path.split('/')[-1])
if lambda_backend.has_function_arn(function_arn):
function = lambda_backend.get_function_by_arn(function_arn)
return 200, {}, json.dumps(dict(Tags=function.tags))
else:
return 404, {}, "{}"
def _tag_resource(self, request, full_url):
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_arn = unquote(path.split('/')[-1])
spec = json.loads(self.body)
if lambda_backend.has_function_arn(function_arn):
lambda_backend.tag_resource(function_arn, spec['Tags'])
return 200, {}, "{}"
else:
return 404, {}, "{}"
def _untag_resource(self, request, full_url):
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_arn = unquote(path.split('/')[-1].split('?')[0])
tag_keys = parse_qs(urlparse(full_url).query)['tagKeys']
if lambda_backend.has_function_arn(function_arn):
lambda_backend.untag_resource(function_arn, tag_keys)
return 204, {}, "{}"
else:
return 404, {}, "{}"

View File

@ -9,6 +9,9 @@ response = LambdaResponse()
url_paths = { url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root, '{0}/(?P<api_version>[^/]+)/functions/?$': response.root,
'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function, r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke, r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$': response.policy
} }

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.acm import acm_backends
from moto.apigateway import apigateway_backends from moto.apigateway import apigateway_backends
from moto.autoscaling import autoscaling_backends from moto.autoscaling import autoscaling_backends
from moto.awslambda import lambda_backends from moto.awslambda import lambda_backends
@ -10,8 +11,10 @@ from moto.datapipeline import datapipeline_backends
from moto.dynamodb import dynamodb_backends from moto.dynamodb import dynamodb_backends
from moto.dynamodb2 import dynamodb_backends2 from moto.dynamodb2 import dynamodb_backends2
from moto.ec2 import ec2_backends from moto.ec2 import ec2_backends
from moto.ecr import ecr_backends
from moto.ecs import ecs_backends from moto.ecs import ecs_backends
from moto.elb import elb_backends from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
from moto.emr import emr_backends from moto.emr import emr_backends
from moto.events import events_backends from moto.events import events_backends
from moto.glacier import glacier_backends from moto.glacier import glacier_backends
@ -19,7 +22,9 @@ from moto.iam import iam_backends
from moto.instance_metadata import instance_metadata_backends from moto.instance_metadata import instance_metadata_backends
from moto.kinesis import kinesis_backends from moto.kinesis import kinesis_backends
from moto.kms import kms_backends from moto.kms import kms_backends
from moto.logs import logs_backends
from moto.opsworks import opsworks_backends from moto.opsworks import opsworks_backends
from moto.polly import polly_backends
from moto.rds2 import rds2_backends from moto.rds2 import rds2_backends
from moto.redshift import redshift_backends from moto.redshift import redshift_backends
from moto.route53 import route53_backends from moto.route53 import route53_backends
@ -29,27 +34,35 @@ from moto.sns import sns_backends
from moto.sqs import sqs_backends from moto.sqs import sqs_backends
from moto.ssm import ssm_backends from moto.ssm import ssm_backends
from moto.sts import sts_backends from moto.sts import sts_backends
from moto.xray import xray_backends
from moto.batch import batch_backends
BACKENDS = { BACKENDS = {
'acm': acm_backends,
'apigateway': apigateway_backends, 'apigateway': apigateway_backends,
'autoscaling': autoscaling_backends, 'autoscaling': autoscaling_backends,
'batch': batch_backends,
'cloudformation': cloudformation_backends, 'cloudformation': cloudformation_backends,
'cloudwatch': cloudwatch_backends, 'cloudwatch': cloudwatch_backends,
'datapipeline': datapipeline_backends, 'datapipeline': datapipeline_backends,
'dynamodb': dynamodb_backends, 'dynamodb': dynamodb_backends,
'dynamodb2': dynamodb_backends2, 'dynamodb2': dynamodb_backends2,
'ec2': ec2_backends, 'ec2': ec2_backends,
'ecr': ecr_backends,
'ecs': ecs_backends, 'ecs': ecs_backends,
'elb': elb_backends, 'elb': elb_backends,
'elbv2': elbv2_backends,
'events': events_backends, 'events': events_backends,
'emr': emr_backends, 'emr': emr_backends,
'glacier': glacier_backends, 'glacier': glacier_backends,
'iam': iam_backends, 'iam': iam_backends,
'moto_api': moto_api_backends, 'moto_api': moto_api_backends,
'instance_metadata': instance_metadata_backends, 'instance_metadata': instance_metadata_backends,
'opsworks': opsworks_backends, 'logs': logs_backends,
'kinesis': kinesis_backends, 'kinesis': kinesis_backends,
'kms': kms_backends, 'kms': kms_backends,
'opsworks': opsworks_backends,
'polly': polly_backends,
'redshift': redshift_backends, 'redshift': redshift_backends,
'rds': rds2_backends, 'rds': rds2_backends,
's3': s3_backends, 's3': s3_backends,
@ -61,6 +74,7 @@ BACKENDS = {
'sts': sts_backends, 'sts': sts_backends,
'route53': route53_backends, 'route53': route53_backends,
'lambda': lambda_backends, 'lambda': lambda_backends,
'xray': xray_backends
} }

6
moto/batch/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import batch_backends
from ..core.models import base_decorator
batch_backend = batch_backends['us-east-1']
mock_batch = base_decorator(batch_backends)

37
moto/batch/exceptions.py Normal file
View File

@ -0,0 +1,37 @@
from __future__ import unicode_literals
import json
class AWSError(Exception):
CODE = None
STATUS = 400
def __init__(self, message, code=None, status=None):
self.message = message
self.code = code if code is not None else self.CODE
self.status = status if status is not None else self.STATUS
def response(self):
return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status)
class InvalidRequestException(AWSError):
CODE = 'InvalidRequestException'
class InvalidParameterValueException(AWSError):
CODE = 'InvalidParameterValue'
class ValidationError(AWSError):
CODE = 'ValidationError'
class InternalFailure(AWSError):
CODE = 'InternalFailure'
STATUS = 500
class ClientException(AWSError):
CODE = 'ClientException'
STATUS = 400

1042
moto/batch/models.py Normal file

File diff suppressed because it is too large Load Diff

296
moto/batch/responses.py Normal file
View File

@ -0,0 +1,296 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from .models import batch_backends
from six.moves.urllib.parse import urlsplit
from .exceptions import AWSError
import json
class BatchResponse(BaseResponse):
def _error(self, code, message):
return json.dumps({'__type': code, 'message': message}), dict(status=400)
@property
def batch_backend(self):
"""
:return: Batch Backend
:rtype: moto.batch.models.BatchBackend
"""
return batch_backends[self.region]
@property
def json(self):
if self.body is None or self.body == '':
self._json = {}
elif not hasattr(self, '_json'):
try:
self._json = json.loads(self.body)
except json.JSONDecodeError:
print()
return self._json
def _get_param(self, param_name, if_none=None):
val = self.json.get(param_name)
if val is not None:
return val
return if_none
def _get_action(self):
# Return element after the /v1/*
return urlsplit(self.uri).path.lstrip('/').split('/')[1]
# CreateComputeEnvironment
def createcomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironmentName')
compute_resource = self._get_param('computeResources')
service_role = self._get_param('serviceRole')
state = self._get_param('state')
_type = self._get_param('type')
try:
name, arn = self.batch_backend.create_compute_environment(
compute_environment_name=compute_env_name,
_type=_type, state=state,
compute_resources=compute_resource,
service_role=service_role
)
except AWSError as err:
return err.response()
result = {
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
return json.dumps(result)
# DescribeComputeEnvironments
def describecomputeenvironments(self):
compute_environments = self._get_param('computeEnvironments')
max_results = self._get_param('maxResults') # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored
envs = self.batch_backend.describe_compute_environments(compute_environments, max_results=max_results, next_token=next_token)
result = {'computeEnvironments': envs}
return json.dumps(result)
# DeleteComputeEnvironment
def deletecomputeenvironment(self):
compute_environment = self._get_param('computeEnvironment')
try:
self.batch_backend.delete_compute_environment(compute_environment)
except AWSError as err:
return err.response()
return ''
# UpdateComputeEnvironment
def updatecomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironment')
compute_resource = self._get_param('computeResources')
service_role = self._get_param('serviceRole')
state = self._get_param('state')
try:
name, arn = self.batch_backend.update_compute_environment(
compute_environment_name=compute_env_name,
compute_resources=compute_resource,
service_role=service_role,
state=state
)
except AWSError as err:
return err.response()
result = {
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
return json.dumps(result)
# CreateJobQueue
def createjobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder')
queue_name = self._get_param('jobQueueName')
priority = self._get_param('priority')
state = self._get_param('state')
try:
name, arn = self.batch_backend.create_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order
)
except AWSError as err:
return err.response()
result = {
'jobQueueArn': arn,
'jobQueueName': name
}
return json.dumps(result)
# DescribeJobQueues
def describejobqueues(self):
job_queues = self._get_param('jobQueues')
max_results = self._get_param('maxResults') # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored
queues = self.batch_backend.describe_job_queues(job_queues, max_results=max_results, next_token=next_token)
result = {'jobQueues': queues}
return json.dumps(result)
# UpdateJobQueue
def updatejobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder')
queue_name = self._get_param('jobQueue')
priority = self._get_param('priority')
state = self._get_param('state')
try:
name, arn = self.batch_backend.update_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order
)
except AWSError as err:
return err.response()
result = {
'jobQueueArn': arn,
'jobQueueName': name
}
return json.dumps(result)
# DeleteJobQueue
def deletejobqueue(self):
queue_name = self._get_param('jobQueue')
self.batch_backend.delete_job_queue(queue_name)
return ''
# RegisterJobDefinition
def registerjobdefinition(self):
container_properties = self._get_param('containerProperties')
def_name = self._get_param('jobDefinitionName')
parameters = self._get_param('parameters')
retry_strategy = self._get_param('retryStrategy')
_type = self._get_param('type')
try:
name, arn, revision = self.batch_backend.register_job_definition(
def_name=def_name,
parameters=parameters,
_type=_type,
retry_strategy=retry_strategy,
container_properties=container_properties
)
except AWSError as err:
return err.response()
result = {
'jobDefinitionArn': arn,
'jobDefinitionName': name,
'revision': revision
}
return json.dumps(result)
# DeregisterJobDefinition
def deregisterjobdefinition(self):
queue_name = self._get_param('jobDefinition')
self.batch_backend.deregister_job_definition(queue_name)
return ''
# DescribeJobDefinitions
def describejobdefinitions(self):
job_def_name = self._get_param('jobDefinitionName')
job_def_list = self._get_param('jobDefinitions')
max_results = self._get_param('maxResults')
next_token = self._get_param('nextToken')
status = self._get_param('status')
job_defs = self.batch_backend.describe_job_definitions(job_def_name, job_def_list, status, max_results, next_token)
result = {'jobDefinitions': [job.describe() for job in job_defs]}
return json.dumps(result)
# SubmitJob
def submitjob(self):
container_overrides = self._get_param('containerOverrides')
depends_on = self._get_param('dependsOn')
job_def = self._get_param('jobDefinition')
job_name = self._get_param('jobName')
job_queue = self._get_param('jobQueue')
parameters = self._get_param('parameters')
retries = self._get_param('retryStrategy')
try:
name, job_id = self.batch_backend.submit_job(
job_name, job_def, job_queue,
parameters=parameters,
retries=retries,
depends_on=depends_on,
container_overrides=container_overrides
)
except AWSError as err:
return err.response()
result = {
'jobId': job_id,
'jobName': name,
}
return json.dumps(result)
# DescribeJobs
def describejobs(self):
jobs = self._get_param('jobs')
try:
return json.dumps({'jobs': self.batch_backend.describe_jobs(jobs)})
except AWSError as err:
return err.response()
# ListJobs
def listjobs(self):
job_queue = self._get_param('jobQueue')
job_status = self._get_param('jobStatus')
max_results = self._get_param('maxResults')
next_token = self._get_param('nextToken')
try:
jobs = self.batch_backend.list_jobs(job_queue, job_status, max_results, next_token)
except AWSError as err:
return err.response()
result = {'jobSummaryList': [{'jobId': job.job_id, 'jobName': job.job_name} for job in jobs]}
return json.dumps(result)
# TerminateJob
def terminatejob(self):
job_id = self._get_param('jobId')
reason = self._get_param('reason')
try:
self.batch_backend.terminate_job(job_id, reason)
except AWSError as err:
return err.response()
return ''
# CancelJob
def canceljob(self): # Theres some AWS semantics on the differences but for us they're identical ;-)
return self.terminatejob()

25
moto/batch/urls.py Normal file
View File

@ -0,0 +1,25 @@
from __future__ import unicode_literals
from .responses import BatchResponse
url_bases = [
"https?://batch.(.+).amazonaws.com",
]
url_paths = {
'{0}/v1/createcomputeenvironment$': BatchResponse.dispatch,
'{0}/v1/describecomputeenvironments$': BatchResponse.dispatch,
'{0}/v1/deletecomputeenvironment': BatchResponse.dispatch,
'{0}/v1/updatecomputeenvironment': BatchResponse.dispatch,
'{0}/v1/createjobqueue': BatchResponse.dispatch,
'{0}/v1/describejobqueues': BatchResponse.dispatch,
'{0}/v1/updatejobqueue': BatchResponse.dispatch,
'{0}/v1/deletejobqueue': BatchResponse.dispatch,
'{0}/v1/registerjobdefinition': BatchResponse.dispatch,
'{0}/v1/deregisterjobdefinition': BatchResponse.dispatch,
'{0}/v1/describejobdefinitions': BatchResponse.dispatch,
'{0}/v1/submitjob': BatchResponse.dispatch,
'{0}/v1/describejobs': BatchResponse.dispatch,
'{0}/v1/listjobs': BatchResponse.dispatch,
'{0}/v1/terminatejob': BatchResponse.dispatch,
'{0}/v1/canceljob': BatchResponse.dispatch,
}

22
moto/batch/utils.py Normal file
View File

@ -0,0 +1,22 @@
from __future__ import unicode_literals
def make_arn_for_compute_env(account_id, name, region_name):
return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(region_name, account_id, name)
def make_arn_for_job_queue(account_id, name, region_name):
return "arn:aws:batch:{0}:{1}:job-queue/{2}".format(region_name, account_id, name)
def make_arn_for_task_def(account_id, name, revision, region_name):
return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(region_name, account_id, name, revision)
def lowercase_first_key(some_dict):
new_dict = {}
for key, value in some_dict.items():
new_key = key[0].lower() + key[1:]
new_dict[new_key] = value
return new_dict

View File

@ -9,13 +9,13 @@ from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from .parsing import ResourceMap, OutputMap from .parsing import ResourceMap, OutputMap
from .utils import generate_stack_id from .utils import generate_stack_id, yaml_tag_constructor
from .exceptions import ValidationError from .exceptions import ValidationError
class FakeStack(BaseModel): class FakeStack(BaseModel):
def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None): def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None):
self.stack_id = stack_id self.stack_id = stack_id
self.name = name self.name = name
self.template = template self.template = template
@ -30,6 +30,7 @@ class FakeStack(BaseModel):
resource_status_reason="User Initiated") resource_status_reason="User Initiated")
self.description = self.template_dict.get('Description') self.description = self.template_dict.get('Description')
self.cross_stack_resources = cross_stack_resources or []
self.resource_map = self._create_resource_map() self.resource_map = self._create_resource_map()
self.output_map = self._create_output_map() self.output_map = self._create_output_map()
self._add_stack_event("CREATE_COMPLETE") self._add_stack_event("CREATE_COMPLETE")
@ -37,12 +38,12 @@ class FakeStack(BaseModel):
def _create_resource_map(self): def _create_resource_map(self):
resource_map = ResourceMap( resource_map = ResourceMap(
self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict) self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict, self.cross_stack_resources)
resource_map.create() resource_map.create()
return resource_map return resource_map
def _create_output_map(self): def _create_output_map(self):
output_map = OutputMap(self.resource_map, self.template_dict) output_map = OutputMap(self.resource_map, self.template_dict, self.stack_id)
output_map.create() output_map.create()
return output_map return output_map
@ -73,6 +74,7 @@ class FakeStack(BaseModel):
)) ))
def _parse_template(self): def _parse_template(self):
yaml.add_multi_constructor('', yaml_tag_constructor)
try: try:
self.template_dict = yaml.load(self.template) self.template_dict = yaml.load(self.template)
except yaml.parser.ParserError: except yaml.parser.ParserError:
@ -90,6 +92,10 @@ class FakeStack(BaseModel):
def stack_outputs(self): def stack_outputs(self):
return self.output_map.values() return self.output_map.values()
@property
def exports(self):
return self.output_map.exports
def update(self, template, role_arn=None, parameters=None, tags=None): def update(self, template, role_arn=None, parameters=None, tags=None):
self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated") self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated")
self.template = template self.template = template
@ -131,6 +137,7 @@ class CloudFormationBackend(BaseBackend):
def __init__(self): def __init__(self):
self.stacks = OrderedDict() self.stacks = OrderedDict()
self.deleted_stacks = {} self.deleted_stacks = {}
self.exports = OrderedDict()
def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None): def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None):
stack_id = generate_stack_id(name) stack_id = generate_stack_id(name)
@ -143,8 +150,12 @@ class CloudFormationBackend(BaseBackend):
notification_arns=notification_arns, notification_arns=notification_arns,
tags=tags, tags=tags,
role_arn=role_arn, role_arn=role_arn,
cross_stack_resources=self.exports,
) )
self.stacks[stack_id] = new_stack self.stacks[stack_id] = new_stack
self._validate_export_uniqueness(new_stack)
for export in new_stack.exports:
self.exports[export.name] = export
return new_stack return new_stack
def describe_stacks(self, name_or_stack_id): def describe_stacks(self, name_or_stack_id):
@ -191,6 +202,7 @@ class CloudFormationBackend(BaseBackend):
stack = self.stacks.pop(name_or_stack_id, None) stack = self.stacks.pop(name_or_stack_id, None)
stack.delete() stack.delete()
self.deleted_stacks[stack.stack_id] = stack self.deleted_stacks[stack.stack_id] = stack
[self.exports.pop(export.name) for export in stack.exports]
return self.stacks.pop(name_or_stack_id, None) return self.stacks.pop(name_or_stack_id, None)
else: else:
# Delete by stack name # Delete by stack name
@ -198,6 +210,23 @@ class CloudFormationBackend(BaseBackend):
if stack.name == name_or_stack_id: if stack.name == name_or_stack_id:
self.delete_stack(stack.stack_id) self.delete_stack(stack.stack_id)
def list_exports(self, token):
all_exports = list(self.exports.values())
if token is None:
exports = all_exports[0:100]
next_token = '100' if len(all_exports) > 100 else None
else:
token = int(token)
exports = all_exports[token:token + 100]
next_token = str(token + 100) if len(all_exports) > token + 100 else None
return exports, next_token
def _validate_export_uniqueness(self, stack):
new_stack_export_names = [x.name for x in stack.exports]
export_names = self.exports.keys()
if not set(export_names).isdisjoint(new_stack_export_names):
raise ValidationError(stack.stack_id, message='Export names must be unique across a given region')
cloudformation_backends = {} cloudformation_backends = {}
for region in boto.cloudformation.regions(): for region in boto.cloudformation.regions():

View File

@ -4,14 +4,19 @@ import functools
import logging import logging
import copy import copy
import warnings import warnings
import re
from moto.autoscaling import models as autoscaling_models from moto.autoscaling import models as autoscaling_models
from moto.awslambda import models as lambda_models from moto.awslambda import models as lambda_models
from moto.batch import models as batch_models
from moto.cloudwatch import models as cloudwatch_models
from moto.datapipeline import models as datapipeline_models from moto.datapipeline import models as datapipeline_models
from moto.dynamodb import models as dynamodb_models
from moto.ec2 import models as ec2_models from moto.ec2 import models as ec2_models
from moto.ecs import models as ecs_models from moto.ecs import models as ecs_models
from moto.elb import models as elb_models from moto.elb import models as elb_models
from moto.iam import models as iam_models from moto.iam import models as iam_models
from moto.kinesis import models as kinesis_models
from moto.kms import models as kms_models from moto.kms import models as kms_models
from moto.rds import models as rds_models from moto.rds import models as rds_models
from moto.rds2 import models as rds2_models from moto.rds2 import models as rds2_models
@ -27,7 +32,14 @@ from boto.cloudformation.stack import Output
MODEL_MAP = { MODEL_MAP = {
"AWS::AutoScaling::AutoScalingGroup": autoscaling_models.FakeAutoScalingGroup, "AWS::AutoScaling::AutoScalingGroup": autoscaling_models.FakeAutoScalingGroup,
"AWS::AutoScaling::LaunchConfiguration": autoscaling_models.FakeLaunchConfiguration, "AWS::AutoScaling::LaunchConfiguration": autoscaling_models.FakeLaunchConfiguration,
"AWS::Batch::JobDefinition": batch_models.JobDefinition,
"AWS::Batch::JobQueue": batch_models.JobQueue,
"AWS::Batch::ComputeEnvironment": batch_models.ComputeEnvironment,
"AWS::DynamoDB::Table": dynamodb_models.Table,
"AWS::Kinesis::Stream": kinesis_models.Stream,
"AWS::Lambda::EventSourceMapping": lambda_models.EventSourceMapping,
"AWS::Lambda::Function": lambda_models.LambdaFunction, "AWS::Lambda::Function": lambda_models.LambdaFunction,
"AWS::Lambda::Version": lambda_models.LambdaVersion,
"AWS::EC2::EIP": ec2_models.ElasticAddress, "AWS::EC2::EIP": ec2_models.ElasticAddress,
"AWS::EC2::Instance": ec2_models.Instance, "AWS::EC2::Instance": ec2_models.Instance,
"AWS::EC2::InternetGateway": ec2_models.InternetGateway, "AWS::EC2::InternetGateway": ec2_models.InternetGateway,
@ -53,6 +65,7 @@ MODEL_MAP = {
"AWS::IAM::InstanceProfile": iam_models.InstanceProfile, "AWS::IAM::InstanceProfile": iam_models.InstanceProfile,
"AWS::IAM::Role": iam_models.Role, "AWS::IAM::Role": iam_models.Role,
"AWS::KMS::Key": kms_models.Key, "AWS::KMS::Key": kms_models.Key,
"AWS::Logs::LogGroup": cloudwatch_models.LogGroup,
"AWS::RDS::DBInstance": rds_models.Database, "AWS::RDS::DBInstance": rds_models.Database,
"AWS::RDS::DBSecurityGroup": rds_models.SecurityGroup, "AWS::RDS::DBSecurityGroup": rds_models.SecurityGroup,
"AWS::RDS::DBSubnetGroup": rds_models.SubnetGroup, "AWS::RDS::DBSubnetGroup": rds_models.SubnetGroup,
@ -133,7 +146,7 @@ def clean_json(resource_json, resources_map):
try: try:
return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1]) return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1])
except NotImplementedError as n: except NotImplementedError as n:
logger.warning(n.message.format( logger.warning(str(n).format(
resource_json['Fn::GetAtt'][0])) resource_json['Fn::GetAtt'][0]))
except UnformattedGetAttTemplateException: except UnformattedGetAttTemplateException:
raise ValidationError( raise ValidationError(
@ -149,12 +162,42 @@ def clean_json(resource_json, resources_map):
return clean_json(false_value, resources_map) return clean_json(false_value, resources_map)
if 'Fn::Join' in resource_json: if 'Fn::Join' in resource_json:
join_list = [] join_list = clean_json(resource_json['Fn::Join'][1], resources_map)
for val in resource_json['Fn::Join'][1]: return resource_json['Fn::Join'][0].join([str(x) for x in join_list])
cleaned_val = clean_json(val, resources_map)
join_list.append('{0}'.format(cleaned_val) if 'Fn::Split' in resource_json:
if cleaned_val else '{0}'.format(val)) to_split = clean_json(resource_json['Fn::Split'][1], resources_map)
return resource_json['Fn::Join'][0].join(join_list) return to_split.split(resource_json['Fn::Split'][0])
if 'Fn::Select' in resource_json:
select_index = int(resource_json['Fn::Select'][0])
select_list = clean_json(resource_json['Fn::Select'][1], resources_map)
return select_list[select_index]
if 'Fn::Sub' in resource_json:
if isinstance(resource_json['Fn::Sub'], list):
warnings.warn(
"Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation")
else:
fn_sub_value = clean_json(resource_json['Fn::Sub'], resources_map)
to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value)
literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value)
for sub in to_sub:
if '.' in sub:
cleaned_ref = clean_json({'Fn::GetAtt': re.findall('(?<=\${)[^"]*?(?=})', sub)[0].split('.')}, resources_map)
else:
cleaned_ref = clean_json({'Ref': re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, resources_map)
fn_sub_value = fn_sub_value.replace(sub, cleaned_ref)
for literal in literals:
fn_sub_value = fn_sub_value.replace(literal, literal.replace('!', ''))
return fn_sub_value
pass
if 'Fn::ImportValue' in resource_json:
cleaned_val = clean_json(resource_json['Fn::ImportValue'], resources_map)
values = [x.value for x in resources_map.cross_stack_resources.values() if x.name == cleaned_val]
if any(values):
return values[0]
cleaned_json = {} cleaned_json = {}
for key, value in resource_json.items(): for key, value in resource_json.items():
@ -295,13 +338,14 @@ class ResourceMap(collections.Mapping):
each resources is passed this lazy map that it can grab dependencies from. each resources is passed this lazy map that it can grab dependencies from.
""" """
def __init__(self, stack_id, stack_name, parameters, tags, region_name, template): def __init__(self, stack_id, stack_name, parameters, tags, region_name, template, cross_stack_resources):
self._template = template self._template = template
self._resource_json_map = template['Resources'] self._resource_json_map = template['Resources']
self._region_name = region_name self._region_name = region_name
self.input_parameters = parameters self.input_parameters = parameters
self.tags = copy.deepcopy(tags) self.tags = copy.deepcopy(tags)
self.resolved_parameters = {} self.resolved_parameters = {}
self.cross_stack_resources = cross_stack_resources
# Create the default resources # Create the default resources
self._parsed_resources = { self._parsed_resources = {
@ -454,8 +498,9 @@ class ResourceMap(collections.Mapping):
class OutputMap(collections.Mapping): class OutputMap(collections.Mapping):
def __init__(self, resources, template): def __init__(self, resources, template, stack_id):
self._template = template self._template = template
self._stack_id = stack_id
self._output_json_map = template.get('Outputs') self._output_json_map = template.get('Outputs')
# Create the default resources # Create the default resources
@ -484,6 +529,37 @@ class OutputMap(collections.Mapping):
def outputs(self): def outputs(self):
return self._output_json_map.keys() if self._output_json_map else [] return self._output_json_map.keys() if self._output_json_map else []
@property
def exports(self):
exports = []
if self.outputs:
for key, value in self._output_json_map.items():
if value.get('Export'):
cleaned_name = clean_json(value['Export'].get('Name'), self._resource_map)
cleaned_value = clean_json(value.get('Value'), self._resource_map)
exports.append(Export(self._stack_id, cleaned_name, cleaned_value))
return exports
def create(self): def create(self):
for output in self.outputs: for output in self.outputs:
self[output] self[output]
class Export(object):
def __init__(self, exporting_stack_id, name, value):
self._exporting_stack_id = exporting_stack_id
self._name = name
self._value = value
@property
def exporting_stack_id(self):
return self._exporting_stack_id
@property
def name(self):
return self._name
@property
def value(self):
return self._value

View File

@ -210,6 +210,12 @@ class CloudFormationResponse(BaseResponse):
template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE) template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE)
return template.render() return template.render()
def list_exports(self):
token = self._get_param('NextToken')
exports, next_token = self.cloudformation_backend.list_exports(token=token)
template = self.response_template(LIST_EXPORTS_RESPONSE)
return template.render(exports=exports, next_token=next_token)
CREATE_STACK_RESPONSE_TEMPLATE = """<CreateStackResponse> CREATE_STACK_RESPONSE_TEMPLATE = """<CreateStackResponse>
<CreateStackResult> <CreateStackResult>
@ -385,8 +391,7 @@ LIST_STACKS_RESOURCES_RESPONSE = """<ListStackResourcesResponse>
GET_TEMPLATE_RESPONSE_TEMPLATE = """<GetTemplateResponse> GET_TEMPLATE_RESPONSE_TEMPLATE = """<GetTemplateResponse>
<GetTemplateResult> <GetTemplateResult>
<TemplateBody>{{ stack.template }} <TemplateBody>{{ stack.template }}</TemplateBody>
</TemplateBody>
</GetTemplateResult> </GetTemplateResult>
<ResponseMetadata> <ResponseMetadata>
<RequestId>b9b4b068-3a41-11e5-94eb-example</RequestId> <RequestId>b9b4b068-3a41-11e5-94eb-example</RequestId>
@ -410,3 +415,23 @@ DELETE_STACK_RESPONSE_TEMPLATE = """<DeleteStackResponse>
</ResponseMetadata> </ResponseMetadata>
</DeleteStackResponse> </DeleteStackResponse>
""" """
LIST_EXPORTS_RESPONSE = """<ListExportsResponse xmlns="http://cloudformation.amazonaws.com/doc/2010-05-15/">
<ListExportsResult>
<Exports>
{% for export in exports %}
<member>
<ExportingStackId>{{ export.exporting_stack_id }}</ExportingStackId>
<Name>{{ export.name }}</Name>
<Value>{{ export.value }}</Value>
</member>
{% endfor %}
</Exports>
{% if next_token %}
<NextToken>{{ next_token }}</NextToken>
{% endif %}
</ListExportsResult>
<ResponseMetadata>
<RequestId>5ccc7dcd-744c-11e5-be70-example</RequestId>
</ResponseMetadata>
</ListExportsResponse>"""

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import uuid import uuid
import six import six
import random import random
import yaml
def generate_stack_id(stack_name): def generate_stack_id(stack_name):
@ -13,3 +14,22 @@ def random_suffix():
size = 12 size = 12
chars = list(range(10)) + ['A-Z'] chars = list(range(10)) + ['A-Z']
return ''.join(six.text_type(random.choice(chars)) for x in range(size)) return ''.join(six.text_type(random.choice(chars)) for x in range(size))
def yaml_tag_constructor(loader, tag, node):
"""convert shorthand intrinsic function to full name
"""
def _f(loader, tag, node):
if tag == '!GetAtt':
return node.value.split('.')
elif type(node) == yaml.SequenceNode:
return loader.construct_sequence(node)
else:
return node.value
if tag == '!Ref':
key = 'Ref'
else:
key = 'Fn::{}'.format(tag[1:])
return {key: _f(loader, tag, node)}

View File

@ -2,6 +2,11 @@ from moto.core import BaseBackend, BaseModel
import boto.ec2.cloudwatch import boto.ec2.cloudwatch
import datetime import datetime
from .utils import make_arn_for_dashboard
DEFAULT_ACCOUNT_ID = 123456789012
class Dimension(object): class Dimension(object):
@ -44,10 +49,34 @@ class MetricDatum(BaseModel):
'value']) for dimension in dimensions] 'value']) for dimension in dimensions]
class Dashboard(BaseModel):
def __init__(self, name, body):
# Guaranteed to be unique for now as the name is also the key of a dictionary where they are stored
self.arn = make_arn_for_dashboard(DEFAULT_ACCOUNT_ID, name)
self.name = name
self.body = body
self.last_modified = datetime.datetime.now()
@property
def last_modified_iso(self):
return self.last_modified.isoformat()
@property
def size(self):
return len(self)
def __len__(self):
return len(self.body)
def __repr__(self):
return '<CloudWatchDashboard {0}>'.format(self.name)
class CloudWatchBackend(BaseBackend): class CloudWatchBackend(BaseBackend):
def __init__(self): def __init__(self):
self.alarms = {} self.alarms = {}
self.dashboards = {}
self.metric_data = [] self.metric_data = []
def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods, def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods,
@ -110,6 +139,52 @@ class CloudWatchBackend(BaseBackend):
def get_all_metrics(self): def get_all_metrics(self):
return self.metric_data return self.metric_data
def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body)
def list_dashboards(self, prefix=''):
for key, value in self.dashboards.items():
if key.startswith(prefix):
yield value
def delete_dashboards(self, dashboards):
to_delete = set(dashboards)
all_dashboards = set(self.dashboards.keys())
left_over = to_delete - all_dashboards
if len(left_over) > 0:
# Some dashboards are not found
return False, 'The specified dashboard does not exist. [{0}]'.format(', '.join(left_over))
for dashboard in to_delete:
del self.dashboards[dashboard]
return True, None
def get_dashboard(self, dashboard):
return self.dashboards.get(dashboard)
class LogGroup(BaseModel):
def __init__(self, spec):
# required
self.name = spec['LogGroupName']
# optional
self.tags = spec.get('Tags', [])
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
spec = {
'LogGroupName': properties['LogGroupName']
}
optional_properties = 'Tags'.split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]
return LogGroup(spec)
cloudwatch_backends = {} cloudwatch_backends = {}
for region in boto.ec2.cloudwatch.regions(): for region in boto.ec2.cloudwatch.regions():

View File

@ -1,9 +1,18 @@
import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import cloudwatch_backends from .models import cloudwatch_backends
class CloudWatchResponse(BaseResponse): class CloudWatchResponse(BaseResponse):
@property
def cloudwatch_backend(self):
return cloudwatch_backends[self.region]
def _error(self, code, message, status=400):
template = self.response_template(ERROR_RESPONSE_TEMPLATE)
return template.render(code=code, message=message), dict(status=status)
def put_metric_alarm(self): def put_metric_alarm(self):
name = self._get_param('AlarmName') name = self._get_param('AlarmName')
namespace = self._get_param('Namespace') namespace = self._get_param('Namespace')
@ -20,15 +29,14 @@ class CloudWatchResponse(BaseResponse):
insufficient_data_actions = self._get_multi_param( insufficient_data_actions = self._get_multi_param(
"InsufficientDataActions.member") "InsufficientDataActions.member")
unit = self._get_param('Unit') unit = self._get_param('Unit')
cloudwatch_backend = cloudwatch_backends[self.region] alarm = self.cloudwatch_backend.put_metric_alarm(name, namespace, metric_name,
alarm = cloudwatch_backend.put_metric_alarm(name, namespace, metric_name, comparison_operator,
comparison_operator, evaluation_periods, period,
evaluation_periods, period, threshold, statistic,
threshold, statistic, description, dimensions,
description, dimensions, alarm_actions, ok_actions,
alarm_actions, ok_actions, insufficient_data_actions,
insufficient_data_actions, unit)
unit)
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE) template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm) return template.render(alarm=alarm)
@ -37,28 +45,26 @@ class CloudWatchResponse(BaseResponse):
alarm_name_prefix = self._get_param('AlarmNamePrefix') alarm_name_prefix = self._get_param('AlarmNamePrefix')
alarm_names = self._get_multi_param('AlarmNames.member') alarm_names = self._get_multi_param('AlarmNames.member')
state_value = self._get_param('StateValue') state_value = self._get_param('StateValue')
cloudwatch_backend = cloudwatch_backends[self.region]
if action_prefix: if action_prefix:
alarms = cloudwatch_backend.get_alarms_by_action_prefix( alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(
action_prefix) action_prefix)
elif alarm_name_prefix: elif alarm_name_prefix:
alarms = cloudwatch_backend.get_alarms_by_alarm_name_prefix( alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix(
alarm_name_prefix) alarm_name_prefix)
elif alarm_names: elif alarm_names:
alarms = cloudwatch_backend.get_alarms_by_alarm_names(alarm_names) alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names)
elif state_value: elif state_value:
alarms = cloudwatch_backend.get_alarms_by_state_value(state_value) alarms = self.cloudwatch_backend.get_alarms_by_state_value(state_value)
else: else:
alarms = cloudwatch_backend.get_all_alarms() alarms = self.cloudwatch_backend.get_all_alarms()
template = self.response_template(DESCRIBE_ALARMS_TEMPLATE) template = self.response_template(DESCRIBE_ALARMS_TEMPLATE)
return template.render(alarms=alarms) return template.render(alarms=alarms)
def delete_alarms(self): def delete_alarms(self):
alarm_names = self._get_multi_param('AlarmNames.member') alarm_names = self._get_multi_param('AlarmNames.member')
cloudwatch_backend = cloudwatch_backends[self.region] self.cloudwatch_backend.delete_alarms(alarm_names)
cloudwatch_backend.delete_alarms(alarm_names)
template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE) template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE)
return template.render() return template.render()
@ -89,17 +95,77 @@ class CloudWatchResponse(BaseResponse):
dimension_index += 1 dimension_index += 1
metric_data.append([metric_name, value, dimensions]) metric_data.append([metric_name, value, dimensions])
metric_index += 1 metric_index += 1
cloudwatch_backend = cloudwatch_backends[self.region] self.cloudwatch_backend.put_metric_data(namespace, metric_data)
cloudwatch_backend.put_metric_data(namespace, metric_data)
template = self.response_template(PUT_METRIC_DATA_TEMPLATE) template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
return template.render() return template.render()
def list_metrics(self): def list_metrics(self):
cloudwatch_backend = cloudwatch_backends[self.region] metrics = self.cloudwatch_backend.get_all_metrics()
metrics = cloudwatch_backend.get_all_metrics()
template = self.response_template(LIST_METRICS_TEMPLATE) template = self.response_template(LIST_METRICS_TEMPLATE)
return template.render(metrics=metrics) return template.render(metrics=metrics)
def delete_dashboards(self):
dashboards = self._get_multi_param('DashboardNames.member')
if dashboards is None:
return self._error('InvalidParameterValue', 'Need at least 1 dashboard')
status, error = self.cloudwatch_backend.delete_dashboards(dashboards)
if not status:
return self._error('ResourceNotFound', error)
template = self.response_template(DELETE_DASHBOARD_TEMPLATE)
return template.render()
def describe_alarm_history(self):
raise NotImplementedError()
def describe_alarms_for_metric(self):
raise NotImplementedError()
def disable_alarm_actions(self):
raise NotImplementedError()
def enable_alarm_actions(self):
raise NotImplementedError()
def get_dashboard(self):
dashboard_name = self._get_param('DashboardName')
dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name)
if dashboard is None:
return self._error('ResourceNotFound', 'Dashboard does not exist')
template = self.response_template(GET_DASHBOARD_TEMPLATE)
return template.render(dashboard=dashboard)
def get_metric_statistics(self):
raise NotImplementedError()
def list_dashboards(self):
prefix = self._get_param('DashboardNamePrefix', '')
dashboards = self.cloudwatch_backend.list_dashboards(prefix)
template = self.response_template(LIST_DASHBOARD_RESPONSE)
return template.render(dashboards=dashboards)
def put_dashboard(self):
name = self._get_param('DashboardName')
body = self._get_param('DashboardBody')
try:
json.loads(body)
except ValueError:
return self._error('InvalidParameterInput', 'Body is invalid JSON')
self.cloudwatch_backend.put_dashboard(name, body)
template = self.response_template(PUT_DASHBOARD_RESPONSE)
return template.render()
def set_alarm_state(self):
raise NotImplementedError()
PUT_METRIC_ALARM_TEMPLATE = """<PutMetricAlarmResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/"> PUT_METRIC_ALARM_TEMPLATE = """<PutMetricAlarmResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata> <ResponseMetadata>
@ -199,3 +265,58 @@ LIST_METRICS_TEMPLATE = """<ListMetricsResponse xmlns="http://monitoring.amazona
</NextToken> </NextToken>
</ListMetricsResult> </ListMetricsResult>
</ListMetricsResponse>""" </ListMetricsResponse>"""
PUT_DASHBOARD_RESPONSE = """<PutDashboardResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<PutDashboardResult>
<DashboardValidationMessages/>
</PutDashboardResult>
<ResponseMetadata>
<RequestId>44b1d4d8-9fa3-11e7-8ad3-41b86ac5e49e</RequestId>
</ResponseMetadata>
</PutDashboardResponse>"""
LIST_DASHBOARD_RESPONSE = """<ListDashboardsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ListDashboardsResult>
<DashboardEntries>
{% for dashboard in dashboards %}
<member>
<DashboardArn>{{ dashboard.arn }}</DashboardArn>
<LastModified>{{ dashboard.last_modified_iso }}</LastModified>
<Size>{{ dashboard.size }}</Size>
<DashboardName>{{ dashboard.name }}</DashboardName>
</member>
{% endfor %}
</DashboardEntries>
</ListDashboardsResult>
<ResponseMetadata>
<RequestId>c3773873-9fa5-11e7-b315-31fcc9275d62</RequestId>
</ResponseMetadata>
</ListDashboardsResponse>"""
DELETE_DASHBOARD_TEMPLATE = """<DeleteDashboardsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<DeleteDashboardsResult/>
<ResponseMetadata>
<RequestId>68d1dc8c-9faa-11e7-a694-df2715690df2</RequestId>
</ResponseMetadata>
</DeleteDashboardsResponse>"""
GET_DASHBOARD_TEMPLATE = """<GetDashboardResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<GetDashboardResult>
<DashboardArn>{{ dashboard.arn }}</DashboardArn>
<DashboardBody>{{ dashboard.body }}</DashboardBody>
<DashboardName>{{ dashboard.name }}</DashboardName>
</GetDashboardResult>
<ResponseMetadata>
<RequestId>e3c16bb0-9faa-11e7-b315-31fcc9275d62</RequestId>
</ResponseMetadata>
</GetDashboardResponse>
"""
ERROR_RESPONSE_TEMPLATE = """<ErrorResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<Error>
<Type>Sender</Type>
<Code>{{ code }}</Code>
<Message>{{ message }}</Message>
</Error>
<RequestId>5e45fd1e-9fa3-11e7-b720-89e8821d38c4</RequestId>
</ErrorResponse>"""

5
moto/cloudwatch/utils.py Normal file
View File

@ -0,0 +1,5 @@
from __future__ import unicode_literals
def make_arn_for_dashboard(account_id, name):
return "arn:aws:cloudwatch::{0}dashboard/{1}".format(account_id, name)

View File

@ -167,7 +167,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
match = re.search(self.region_regex, full_url) match = re.search(self.region_regex, full_url)
if match: if match:
region = match.group(1) region = match.group(1)
elif 'Authorization' in request.headers: elif 'Authorization' in request.headers and 'AWS4' in request.headers['Authorization']:
region = request.headers['Authorization'].split(",")[ region = request.headers['Authorization'].split(",")[
0].split("/")[2] 0].split("/")[2]
else: else:
@ -178,8 +178,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
return self.call_action() return self.call_action()
def call_action(self): def _get_action(self):
headers = self.response_headers
action = self.querystring.get('Action', [""])[0] action = self.querystring.get('Action', [""])[0]
if not action: # Some services use a header for the action if not action: # Some services use a header for the action
# Headers are case-insensitive. Probably a better way to do this. # Headers are case-insensitive. Probably a better way to do this.
@ -188,7 +187,11 @@ class BaseResponse(_TemplateEnvironmentMixin):
if match: if match:
action = match.split(".")[-1] action = match.split(".")[-1]
action = camelcase_to_underscores(action) return action
def call_action(self):
headers = self.response_headers
action = camelcase_to_underscores(self._get_action())
method_names = method_names_from_class(self.__class__) method_names = method_names_from_class(self.__class__)
if action in method_names: if action in method_names:
method = getattr(self, action) method = getattr(self, action)
@ -196,10 +199,14 @@ class BaseResponse(_TemplateEnvironmentMixin):
response = method() response = method()
except HTTPException as http_error: except HTTPException as http_error:
response = http_error.description, dict(status=http_error.code) response = http_error.description, dict(status=http_error.code)
if isinstance(response, six.string_types): if isinstance(response, six.string_types):
return 200, headers, response return 200, headers, response
else: else:
body, new_headers = response if len(response) == 2:
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200) status = new_headers.get('status', 200)
headers.update(new_headers) headers.update(new_headers)
# Cast status to string # Cast status to string
@ -310,7 +317,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
param_index += 1 param_index += 1
return results return results
def _get_map_prefix(self, param_prefix): def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'):
results = {} results = {}
param_index = 1 param_index = 1
while 1: while 1:
@ -319,9 +326,9 @@ class BaseResponse(_TemplateEnvironmentMixin):
k, v = None, None k, v = None, None
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith(index_prefix): if key.startswith(index_prefix):
if key.endswith('.key'): if key.endswith(key_end):
k = value[0] k = value[0]
elif key.endswith('.value'): elif key.endswith(value_end):
v = value[0] v = value[0]
if not (k and v): if not (k and v):
@ -414,6 +421,9 @@ class _RecursiveDictRef(object):
def __getattr__(self, key): def __getattr__(self, key):
return self.dic.__getattr__(key) return self.dic.__getattr__(key)
def __getitem__(self, key):
return self.dic.__getitem__(key)
def set_reference(self, key, dic): def set_reference(self, key, dic):
"""Set the RecursiveDictRef object to keep reference to dict object """Set the RecursiveDictRef object to keep reference to dict object
(dic) at the key. (dic) at the key.

View File

@ -1,10 +1,16 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from functools import wraps
import binascii
import datetime import datetime
import inspect import inspect
import random import random
import re import re
import six import six
import string
REQUEST_ID_LONG = string.digits + string.ascii_uppercase
def camelcase_to_underscores(argument): def camelcase_to_underscores(argument):
@ -174,11 +180,17 @@ def iso_8601_datetime_without_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z' return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z'
RFC1123 = '%a, %d %b %Y %H:%M:%S GMT'
def rfc_1123_datetime(datetime): def rfc_1123_datetime(datetime):
RFC1123 = '%a, %d %b %Y %H:%M:%S GMT'
return datetime.strftime(RFC1123) return datetime.strftime(RFC1123)
def str_to_rfc_1123_datetime(str):
return datetime.datetime.strptime(str, RFC1123)
def unix_time(dt=None): def unix_time(dt=None):
dt = dt or datetime.datetime.utcnow() dt = dt or datetime.datetime.utcnow()
epoch = datetime.datetime.utcfromtimestamp(0) epoch = datetime.datetime.utcfromtimestamp(0)
@ -188,3 +200,90 @@ def unix_time(dt=None):
def unix_time_millis(dt=None): def unix_time_millis(dt=None):
return unix_time(dt) * 1000.0 return unix_time(dt) * 1000.0
def gen_amz_crc32(response, headerdict=None):
if not isinstance(response, bytes):
response = response.encode()
crc = str(binascii.crc32(response))
if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amz-crc32': crc})
return crc
def gen_amzn_requestid_long(headerdict=None):
req_id = ''.join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amzn-requestid': req_id})
return req_id
def amz_crc32(f):
@wraps(f)
def _wrapper(*args, **kwargs):
response = f(*args, **kwargs)
headers = {}
status = 200
if isinstance(response, six.string_types):
body = response
else:
if len(response) == 2:
body, new_headers = response
status = new_headers.get('status', 200)
else:
status, new_headers, body = response
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
try:
# Doesnt work on python2 for some odd unicode strings
gen_amz_crc32(body, headers)
except Exception:
pass
return status, headers, body
return _wrapper
def amzn_request_id(f):
@wraps(f)
def _wrapper(*args, **kwargs):
response = f(*args, **kwargs)
headers = {}
status = 200
if isinstance(response, six.string_types):
body = response
else:
if len(response) == 2:
body, new_headers = response
status = new_headers.get('status', 200)
else:
status, new_headers, body = response
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
request_id = gen_amzn_requestid_long(headers)
# Update request ID in XML
try:
body = body.replace('{{ requestid }}', request_id)
except Exception: # Will just ignore if it cant work on bytes (which are str's on python2)
pass
return status, headers, body
return _wrapper

View File

@ -137,6 +137,20 @@ class Table(BaseModel):
} }
return results return results
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
key_attr = [i['AttributeName'] for i in properties['KeySchema'] if i['KeyType'] == 'HASH'][0]
key_type = [i['AttributeType'] for i in properties['AttributeDefinitions'] if i['AttributeName'] == key_attr][0]
spec = {
'name': properties['TableName'],
'hash_key_attr': key_attr,
'hash_key_type': key_type
}
# TODO: optional properties still missing:
# range_key_attr, range_key_type, read_capacity, write_capacity
return Table(**spec)
def __len__(self): def __len__(self):
count = 0 count = 0
for key, value in self.items.items(): for key, value in self.items.items():
@ -245,6 +259,14 @@ class Table(BaseModel):
except KeyError: except KeyError:
return None return None
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'StreamArn':
region = 'us-east-1'
time = '2000-01-01T00:00:00.000'
return 'arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}'.format(region, self.name, time)
raise UnformattedGetAttTemplateException()
class DynamoDBBackend(BaseBackend): class DynamoDBBackend(BaseBackend):

View File

@ -7,33 +7,6 @@ from moto.core.utils import camelcase_to_underscores
from .models import dynamodb_backend, dynamo_json_dump from .models import dynamodb_backend, dynamo_json_dump
GET_SESSION_TOKEN_RESULT = """
<GetSessionTokenResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetSessionTokenResult>
<Credentials>
<SessionToken>
AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/L
To6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3z
rkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtp
Z3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE
</SessionToken>
<SecretAccessKey>
wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY
</SecretAccessKey>
<Expiration>2011-07-11T19:55:29.611Z</Expiration>
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId>
</Credentials>
</GetSessionTokenResult>
<ResponseMetadata>
<RequestId>58c5dbae-abef-11e0-8cfe-09039844ac7d</RequestId>
</ResponseMetadata>
</GetSessionTokenResponse>"""
def sts_handler():
return GET_SESSION_TOKEN_RESULT
class DynamoHandler(BaseResponse): class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers): def get_endpoint_name(self, headers):
@ -51,11 +24,7 @@ class DynamoHandler(BaseResponse):
return status, self.response_headers, dynamo_json_dump({'__type': type_}) return status, self.response_headers, dynamo_json_dump({'__type': type_})
def call_action(self): def call_action(self):
body = self.body self.body = json.loads(self.body or '{}')
if 'GetSessionToken' in body:
return 200, self.response_headers, sts_handler()
self.body = json.loads(body or '{}')
endpoint = self.get_endpoint_name(self.headers) endpoint = self.get_endpoint_name(self.headers)
if endpoint: if endpoint:
endpoint = camelcase_to_underscores(endpoint) endpoint = camelcase_to_underscores(endpoint)

View File

@ -2,8 +2,7 @@ from __future__ import unicode_literals
from .responses import DynamoHandler from .responses import DynamoHandler
url_bases = [ url_bases = [
"https?://dynamodb.(.+).amazonaws.com", "https?://dynamodb.(.+).amazonaws.com"
"https?://sts.amazonaws.com",
] ]
url_paths = { url_paths = {

View File

@ -1,4 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re
import six
# TODO add tests for all of these # TODO add tests for all of these
EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa
@ -39,3 +41,490 @@ COMPARISON_FUNCS = {
def get_comparison_func(range_comparison): def get_comparison_func(range_comparison):
return COMPARISON_FUNCS.get(range_comparison) return COMPARISON_FUNCS.get(range_comparison)
class RecursionStopIteration(StopIteration):
pass
def get_filter_expression(expr, names, values):
# Examples
# expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)'
# expr = 'Id > 5 AND Subs < 7'
if names is None:
names = {}
if values is None:
values = {}
# Do substitutions
for key, value in names.items():
expr = expr.replace(key, value)
# Store correct types of values for use later
values_map = {}
for key, value in values.items():
if 'N' in value:
values_map[key] = float(value['N'])
elif 'BOOL' in value:
values_map[key] = value['BOOL']
elif 'S' in value:
values_map[key] = value['S']
elif 'NS' in value:
values_map[key] = tuple(value['NS'])
elif 'SS' in value:
values_map[key] = tuple(value['SS'])
elif 'L' in value:
values_map[key] = tuple(value['L'])
else:
raise NotImplementedError()
# Remove all spaces, tbf we could just skip them in the next step.
# The number of known options is really small so we can do a fair bit of cheating
expr = list(expr.strip())
# DodgyTokenisation stage 1
def is_value(val):
return val not in ('<', '>', '=', '(', ')')
def contains_keyword(val):
for kw in ('BETWEEN', 'IN', 'AND', 'OR', 'NOT'):
if kw in val:
return kw
return None
def is_function(val):
return val in ('attribute_exists', 'attribute_not_exists', 'attribute_type', 'begins_with', 'contains', 'size')
# Does the main part of splitting between sections of characters
tokens = []
stack = ''
while len(expr) > 0:
current_char = expr.pop(0)
if current_char == ' ':
if len(stack) > 0:
tokens.append(stack)
stack = ''
elif current_char == ',': # Split params ,
if len(stack) > 0:
tokens.append(stack)
stack = ''
elif is_value(current_char):
stack += current_char
kw = contains_keyword(stack)
if kw is not None:
# We have a kw in the stack, could be AND or something like 5AND
tmp = stack.replace(kw, '')
if len(tmp) > 0:
tokens.append(tmp)
tokens.append(kw)
stack = ''
else:
if len(stack) > 0:
tokens.append(stack)
tokens.append(current_char)
stack = ''
if len(stack) > 0:
tokens.append(stack)
def is_op(val):
return val in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT')
# DodgyTokenisation stage 2, it groups together some elements to make RPN'ing it later easier.
def handle_token(token, tokens2, token_iterator):
# ok so this essentially groups up some tokens to make later parsing easier,
# when it encounters brackets it will recurse and then unrecurse when RecursionStopIteration is raised.
if token == ')':
raise RecursionStopIteration() # Should be recursive so this should work
elif token == '(':
temp_list = []
try:
while True:
next_token = six.next(token_iterator)
handle_token(next_token, temp_list, token_iterator)
except RecursionStopIteration:
pass # Continue
except StopIteration:
ValueError('Malformed filter expression, type1')
# Sigh, we only want to group a tuple if it doesnt contain operators
if any([is_op(item) for item in temp_list]):
# Its an expression
tokens2.append('(')
tokens2.extend(temp_list)
tokens2.append(')')
else:
tokens2.append(tuple(temp_list))
elif token == 'BETWEEN':
field = tokens2.pop()
# if values map contains a number, it would be a float
# so we need to int() it anyway
op1 = six.next(token_iterator)
op1 = int(values_map.get(op1, op1))
and_op = six.next(token_iterator)
assert and_op == 'AND'
op2 = six.next(token_iterator)
op2 = int(values_map.get(op2, op2))
tokens2.append(['between', field, op1, op2])
elif is_function(token):
function_list = [token]
lbracket = six.next(token_iterator)
assert lbracket == '('
next_token = six.next(token_iterator)
while next_token != ')':
function_list.append(next_token)
next_token = six.next(token_iterator)
tokens2.append(function_list)
else:
# Convert tokens back to real types
if token in values_map:
token = values_map[token]
# Need to join >= <= <>
if len(tokens2) > 0 and ((tokens2[-1] == '>' and token == '=') or (tokens2[-1] == '<' and token == '=') or (tokens2[-1] == '<' and token == '>')):
tokens2.append(tokens2.pop() + token)
else:
tokens2.append(token)
tokens2 = []
token_iterator = iter(tokens)
for token in token_iterator:
handle_token(token, tokens2, token_iterator)
# Start of the Shunting-Yard algorithm. <-- Proper beast algorithm!
def is_number(val):
return val not in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT')
OPS = {'<': 5, '>': 5, '=': 5, '>=': 5, '<=': 5, '<>': 5, 'IN': 8, 'AND': 11, 'OR': 12, 'NOT': 10, 'BETWEEN': 9, '(': 100, ')': 100}
def shunting_yard(token_list):
output = []
op_stack = []
# Basically takes in an infix notation calculation, converts it to a reverse polish notation where there is no
# ambiguity on which order operators are applied.
while len(token_list) > 0:
token = token_list.pop(0)
if token == '(':
op_stack.append(token)
elif token == ')':
while len(op_stack) > 0 and op_stack[-1] != '(':
output.append(op_stack.pop())
lbracket = op_stack.pop()
assert lbracket == '('
elif is_number(token):
output.append(token)
else:
# Must be operator kw
# Cheat, NOT is our only RIGHT associative operator, should really have dict of operator associativity
while len(op_stack) > 0 and OPS[op_stack[-1]] <= OPS[token] and op_stack[-1] != 'NOT':
output.append(op_stack.pop())
op_stack.append(token)
while len(op_stack) > 0:
output.append(op_stack.pop())
return output
output = shunting_yard(tokens2)
# Hacky function to convert dynamo functions (which are represented as lists) to their Class equivalent
def to_func(val):
if isinstance(val, list):
func_name = val.pop(0)
# Expand rest of the list to arguments
val = FUNC_CLASS[func_name](*val)
return val
# Simple reverse polish notation execution. Builts up a nested filter object.
# The filter object then takes a dynamo item and returns true/false
stack = []
for token in output:
if is_op(token):
op_cls = OP_CLASS[token]
if token == 'NOT':
op1 = stack.pop()
op2 = True
else:
op2 = stack.pop()
op1 = stack.pop()
stack.append(op_cls(op1, op2))
else:
stack.append(to_func(token))
result = stack.pop(0)
if len(stack) > 0:
raise ValueError('Malformed filter expression, type2')
return result
class Op(object):
"""
Base class for a FilterExpression operator
"""
OP = ''
def __init__(self, lhs, rhs):
self.lhs = lhs
self.rhs = rhs
def _lhs(self, item):
"""
:type item: moto.dynamodb2.models.Item
"""
lhs = self.lhs
if isinstance(self.lhs, (Op, Func)):
lhs = self.lhs.expr(item)
elif isinstance(self.lhs, six.string_types):
try:
lhs = item.attrs[self.lhs].cast_value
except Exception:
pass
return lhs
def _rhs(self, item):
rhs = self.rhs
if isinstance(self.rhs, (Op, Func)):
rhs = self.rhs.expr(item)
elif isinstance(self.rhs, six.string_types):
try:
rhs = item.attrs[self.rhs].cast_value
except Exception:
pass
return rhs
def expr(self, item):
return True
def __repr__(self):
return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs)
class Func(object):
"""
Base class for a FilterExpression function
"""
FUNC = 'Unknown'
def expr(self, item):
return True
def __repr__(self):
return 'Func(...)'.format(self.FUNC)
class OpNot(Op):
OP = 'NOT'
def expr(self, item):
lhs = self._lhs(item)
return not lhs
def __str__(self):
return '({0} {1})'.format(self.OP, self.lhs)
class OpAnd(Op):
OP = 'AND'
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs and rhs
class OpLessThan(Op):
OP = '<'
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs < rhs
class OpGreaterThan(Op):
OP = '>'
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs > rhs
class OpEqual(Op):
OP = '='
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs == rhs
class OpNotEqual(Op):
OP = '<>'
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs == rhs
class OpLessThanOrEqual(Op):
OP = '<='
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs <= rhs
class OpGreaterThanOrEqual(Op):
OP = '>='
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs >= rhs
class OpOr(Op):
OP = 'OR'
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs or rhs
class OpIn(Op):
OP = 'IN'
def expr(self, item):
lhs = self._lhs(item)
rhs = self._rhs(item)
return lhs in rhs
class FuncAttrExists(Func):
FUNC = 'attribute_exists'
def __init__(self, attribute):
self.attr = attribute
def expr(self, item):
return self.attr in item.attrs
class FuncAttrNotExists(Func):
FUNC = 'attribute_not_exists'
def __init__(self, attribute):
self.attr = attribute
def expr(self, item):
return self.attr not in item.attrs
class FuncAttrType(Func):
FUNC = 'attribute_type'
def __init__(self, attribute, _type):
self.attr = attribute
self.type = _type
def expr(self, item):
return self.attr in item.attrs and item.attrs[self.attr].type == self.type
class FuncBeginsWith(Func):
FUNC = 'begins_with'
def __init__(self, attribute, substr):
self.attr = attribute
self.substr = substr
def expr(self, item):
return self.attr in item.attrs and item.attrs[self.attr].type == 'S' and item.attrs[self.attr].value.startswith(self.substr)
class FuncContains(Func):
FUNC = 'contains'
def __init__(self, attribute, operand):
self.attr = attribute
self.operand = operand
def expr(self, item):
if self.attr not in item.attrs:
return False
if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'BS', 'L', 'M'):
return self.operand in item.attrs[self.attr].value
return False
class FuncSize(Func):
FUNC = 'contains'
def __init__(self, attribute):
self.attr = attribute
def expr(self, item):
if self.attr not in item.attrs:
raise ValueError('Invalid attribute name {0}'.format(self.attr))
if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'):
return len(item.attrs[self.attr].value)
raise ValueError('Invalid filter expression')
class FuncBetween(Func):
FUNC = 'between'
def __init__(self, attribute, start, end):
self.attr = attribute
self.start = start
self.end = end
def expr(self, item):
if self.attr not in item.attrs:
raise ValueError('Invalid attribute name {0}'.format(self.attr))
return self.start <= item.attrs[self.attr].cast_value <= self.end
OP_CLASS = {
'NOT': OpNot,
'AND': OpAnd,
'OR': OpOr,
'IN': OpIn,
'<': OpLessThan,
'>': OpGreaterThan,
'<=': OpLessThanOrEqual,
'>=': OpGreaterThanOrEqual,
'=': OpEqual,
'<>': OpNotEqual
}
FUNC_CLASS = {
'attribute_exists': FuncAttrExists,
'attribute_not_exists': FuncAttrNotExists,
'attribute_type': FuncAttrType,
'begins_with': FuncBeginsWith,
'contains': FuncContains,
'size': FuncSize,
'between': FuncBetween
}

View File

@ -3,11 +3,12 @@ from collections import defaultdict
import datetime import datetime
import decimal import decimal
import json import json
import re
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
from .comparisons import get_comparison_func from .comparisons import get_comparison_func, get_filter_expression, Op
class DynamoJsonEncoder(json.JSONEncoder): class DynamoJsonEncoder(json.JSONEncoder):
@ -56,7 +57,7 @@ class DynamoType(object):
@property @property
def cast_value(self): def cast_value(self):
if self.type == 'N': if self.is_number():
try: try:
return int(self.value) return int(self.value)
except ValueError: except ValueError:
@ -75,6 +76,15 @@ class DynamoType(object):
comparison_func = get_comparison_func(range_comparison) comparison_func = get_comparison_func(range_comparison)
return comparison_func(self.cast_value, *range_values) return comparison_func(self.cast_value, *range_values)
def is_number(self):
return self.type == 'N'
def is_set(self):
return self.type == 'SS' or self.type == 'NS' or self.type == 'BS'
def same_type(self, other):
return self.type == other.type
class Item(BaseModel): class Item(BaseModel):
@ -115,28 +125,81 @@ class Item(BaseModel):
} }
def update(self, update_expression, expression_attribute_names, expression_attribute_values): def update(self, update_expression, expression_attribute_names, expression_attribute_values):
ACTION_VALUES = ['SET', 'set', 'REMOVE', 'remove'] # Update subexpressions are identifiable by the operator keyword, so split on that and
# get rid of the empty leading string.
action = None parts = [p for p in re.split(r'\b(SET|REMOVE|ADD|DELETE)\b', update_expression, flags=re.I) if p]
for value in update_expression.split(): # make sure that we correctly found only operator/value pairs
if value in ACTION_VALUES: assert len(parts) % 2 == 0, "Mismatched operators and values in update expression: '{}'".format(update_expression)
# An action for action, valstr in zip(parts[:-1:2], parts[1::2]):
action = value action = action.upper()
continue values = valstr.split(',')
else: for value in values:
# A Real value # A Real value
value = value.lstrip(":").rstrip(",") value = value.lstrip(":").rstrip(",").strip()
for k, v in expression_attribute_names.items(): for k, v in expression_attribute_names.items():
value = value.replace(k, v) value = re.sub(r'{0}\b'.format(k), v, value)
if action == "REMOVE" or action == 'remove':
self.attrs.pop(value, None) if action == "REMOVE":
elif action == 'SET' or action == 'set': self.attrs.pop(value, None)
key, value = value.split("=") elif action == 'SET':
if value in expression_attribute_values: key, value = value.split("=")
self.attrs[key] = DynamoType( key = key.strip()
expression_attribute_values[value]) value = value.strip()
if value in expression_attribute_values:
self.attrs[key] = DynamoType(expression_attribute_values[value])
else:
self.attrs[key] = DynamoType({"S": value})
elif action == 'ADD':
key, value = value.split(" ", 1)
key = key.strip()
value_str = value.strip()
if value_str in expression_attribute_values:
dyn_value = DynamoType(expression_attribute_values[value])
else:
raise TypeError
# Handle adding numbers - value gets added to existing value,
# or added to 0 if it doesn't exist yet
if dyn_value.is_number():
existing = self.attrs.get(key, DynamoType({"N": '0'}))
if not existing.same_type(dyn_value):
raise TypeError()
self.attrs[key] = DynamoType({"N": str(
decimal.Decimal(existing.value) +
decimal.Decimal(dyn_value.value)
)})
# Handle adding sets - value is added to the set, or set is
# created with only this value if it doesn't exist yet
# New value must be of same set type as previous value
elif dyn_value.is_set():
existing = self.attrs.get(key, DynamoType({dyn_value.type: {}}))
if not existing.same_type(dyn_value):
raise TypeError()
new_set = set(existing.value).union(dyn_value.value)
self.attrs[key] = DynamoType({existing.type: list(new_set)})
else: # Number and Sets are the only supported types for ADD
raise TypeError
elif action == 'DELETE':
key, value = value.split(" ", 1)
key = key.strip()
value_str = value.strip()
if value_str in expression_attribute_values:
dyn_value = DynamoType(expression_attribute_values[value])
else:
raise TypeError
if not dyn_value.is_set():
raise TypeError
existing = self.attrs.get(key, None)
if existing:
if not existing.same_type(dyn_value):
raise TypeError
new_set = set(existing.value).difference(dyn_value.value)
self.attrs[key] = DynamoType({existing.type: list(new_set)})
else: else:
self.attrs[key] = DynamoType({"S": value}) raise NotImplementedError('{} update action not yet supported'.format(action))
def update_with_attribute_updates(self, attribute_updates): def update_with_attribute_updates(self, attribute_updates):
for attribute_name, update_action in attribute_updates.items(): for attribute_name, update_action in attribute_updates.items():
@ -167,6 +230,12 @@ class Item(BaseModel):
decimal.Decimal(existing.value) + decimal.Decimal(existing.value) +
decimal.Decimal(new_value) decimal.Decimal(new_value)
)}) )})
elif set(update_action['Value'].keys()) == set(['SS']):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
new_set = set(existing.value).union(set(new_value))
self.attrs[attribute_name] = DynamoType({
"SS": list(new_set)
})
else: else:
# TODO: implement other data types # TODO: implement other data types
raise NotImplementedError( raise NotImplementedError(
@ -343,9 +412,9 @@ class Table(BaseModel):
return None return None
def query(self, hash_key, range_comparison, range_objs, limit, def query(self, hash_key, range_comparison, range_objs, limit,
exclusive_start_key, scan_index_forward, index_name=None, **filter_kwargs): exclusive_start_key, scan_index_forward, projection_expression,
index_name=None, **filter_kwargs):
results = [] results = []
if index_name: if index_name:
all_indexes = (self.global_indexes or []) + (self.indexes or []) all_indexes = (self.global_indexes or []) + (self.indexes or [])
indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) indexes_by_name = dict((i['IndexName'], i) for i in all_indexes)
@ -415,6 +484,13 @@ class Table(BaseModel):
else: else:
results.sort(key=lambda item: item.range_key) results.sort(key=lambda item: item.range_key)
if projection_expression:
expressions = [x.strip() for x in projection_expression.split(',')]
for result in possible_results:
for attr in list(result.attrs):
if attr not in expressions:
result.attrs.pop(attr)
if scan_index_forward is False: if scan_index_forward is False:
results.reverse() results.reverse()
@ -432,15 +508,15 @@ class Table(BaseModel):
else: else:
yield hash_set yield hash_set
def scan(self, filters, limit, exclusive_start_key): def scan(self, filters, limit, exclusive_start_key, filter_expression=None):
results = [] results = []
scanned_count = 0 scanned_count = 0
for result in self.all_items(): for item in self.all_items():
scanned_count += 1 scanned_count += 1
passes_all_conditions = True passes_all_conditions = True
for attribute_name, (comparison_operator, comparison_objs) in filters.items(): for attribute_name, (comparison_operator, comparison_objs) in filters.items():
attribute = result.attrs.get(attribute_name) attribute = item.attrs.get(attribute_name)
if attribute: if attribute:
# Attribute found # Attribute found
@ -456,8 +532,11 @@ class Table(BaseModel):
passes_all_conditions = False passes_all_conditions = False
break break
if filter_expression is not None:
passes_all_conditions &= filter_expression.expr(item)
if passes_all_conditions: if passes_all_conditions:
results.append(result) results.append(item)
results, last_evaluated_key = self._trim_results(results, limit, results, last_evaluated_key = self._trim_results(results, limit,
exclusive_start_key) exclusive_start_key)
@ -610,7 +689,7 @@ class DynamoDBBackend(BaseBackend):
return table.get_item(hash_key, range_key) return table.get_item(hash_key, range_key)
def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts, def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts,
limit, exclusive_start_key, scan_index_forward, index_name=None, **filter_kwargs): limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None, **filter_kwargs):
table = self.tables.get(table_name) table = self.tables.get(table_name)
if not table: if not table:
return None, None return None, None
@ -620,9 +699,9 @@ class DynamoDBBackend(BaseBackend):
for range_value in range_value_dicts] for range_value in range_value_dicts]
return table.query(hash_key, range_comparison, range_values, limit, return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, index_name, **filter_kwargs) exclusive_start_key, scan_index_forward, projection_expression, index_name, **filter_kwargs)
def scan(self, table_name, filters, limit, exclusive_start_key): def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values):
table = self.tables.get(table_name) table = self.tables.get(table_name)
if not table: if not table:
return None, None, None return None, None, None
@ -632,9 +711,15 @@ class DynamoDBBackend(BaseBackend):
dynamo_types = [DynamoType(value) for value in comparison_values] dynamo_types = [DynamoType(value) for value in comparison_values]
scan_filters[key] = (comparison_operator, dynamo_types) scan_filters[key] = (comparison_operator, dynamo_types)
return table.scan(scan_filters, limit, exclusive_start_key) if filter_expression is not None:
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names, expression_attribute_values): return table.scan(scan_filters, limit, exclusive_start_key, filter_expression)
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected=None):
table = self.get_table(table_name) table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]): if all([table.hash_key_attr in key, table.range_key_attr in key]):
@ -652,6 +737,34 @@ class DynamoDBBackend(BaseBackend):
range_value = None range_value = None
item = table.get_item(hash_value, range_value) item = table.get_item(hash_value, range_value)
if item is None:
item_attr = {}
elif hasattr(item, 'attrs'):
item_attr = item.attrs
else:
item_attr = item
if not expected:
expected = {}
for key, val in expected.items():
if 'Exists' in val and val['Exists'] is False:
if key in item_attr:
raise ValueError("The conditional request failed")
elif key not in item_attr:
raise ValueError("The conditional request failed")
elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value:
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
comparison_func = get_comparison_func(
val['ComparisonOperator'])
dynamo_types = [DynamoType(ele) for ele in val[
"AttributeValueList"]]
for t in dynamo_types:
if not comparison_func(item_attr[key].value, t.value):
raise ValueError('The conditional request failed')
# Update does not fail on new items, so create one # Update does not fail on new items, so create one
if item is None: if item is None:
data = { data = {

View File

@ -4,37 +4,10 @@ import six
import re import re
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores, amzn_request_id
from .models import dynamodb_backend2, dynamo_json_dump from .models import dynamodb_backend2, dynamo_json_dump
GET_SESSION_TOKEN_RESULT = """
<GetSessionTokenResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetSessionTokenResult>
<Credentials>
<SessionToken>
AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/L
To6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3z
rkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtp
Z3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE
</SessionToken>
<SecretAccessKey>
wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY
</SecretAccessKey>
<Expiration>2011-07-11T19:55:29.611Z</Expiration>
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId>
</Credentials>
</GetSessionTokenResult>
<ResponseMetadata>
<RequestId>58c5dbae-abef-11e0-8cfe-09039844ac7d</RequestId>
</ResponseMetadata>
</GetSessionTokenResponse>"""
def sts_handler():
return GET_SESSION_TOKEN_RESULT
class DynamoHandler(BaseResponse): class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers): def get_endpoint_name(self, headers):
@ -48,15 +21,12 @@ class DynamoHandler(BaseResponse):
if match: if match:
return match.split(".")[1] return match.split(".")[1]
def error(self, type_, status=400): def error(self, type_, message, status=400):
return status, self.response_headers, dynamo_json_dump({'__type': type_}) return status, self.response_headers, dynamo_json_dump({'__type': type_, 'message': message})
@amzn_request_id
def call_action(self): def call_action(self):
body = self.body self.body = json.loads(self.body or '{}')
if 'GetSessionToken' in body:
return 200, self.response_headers, sts_handler()
self.body = json.loads(body or '{}')
endpoint = self.get_endpoint_name(self.headers) endpoint = self.get_endpoint_name(self.headers)
if endpoint: if endpoint:
endpoint = camelcase_to_underscores(endpoint) endpoint = camelcase_to_underscores(endpoint)
@ -87,6 +57,7 @@ class DynamoHandler(BaseResponse):
response = {"TableNames": tables} response = {"TableNames": tables}
if limit and len(all_tables) > start + limit: if limit and len(all_tables) > start + limit:
response["LastEvaluatedTableName"] = tables[-1] response["LastEvaluatedTableName"] = tables[-1]
return dynamo_json_dump(response) return dynamo_json_dump(response)
def create_table(self): def create_table(self):
@ -113,7 +84,7 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(table.describe()) return dynamo_json_dump(table.describe())
else: else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceInUseException' er = 'com.amazonaws.dynamodb.v20111205#ResourceInUseException'
return self.error(er) return self.error(er, 'Resource in use')
def delete_table(self): def delete_table(self):
name = self.body['TableName'] name = self.body['TableName']
@ -122,7 +93,7 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(table.describe()) return dynamo_json_dump(table.describe())
else: else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er) return self.error(er, 'Requested resource not found')
def tag_resource(self): def tag_resource(self):
tags = self.body['Tags'] tags = self.body['Tags']
@ -151,7 +122,7 @@ class DynamoHandler(BaseResponse):
return json.dumps({'Tags': tags_resp}) return json.dumps({'Tags': tags_resp})
except AttributeError: except AttributeError:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er) return self.error(er, 'Requested resource not found')
def update_table(self): def update_table(self):
name = self.body['TableName'] name = self.body['TableName']
@ -169,12 +140,24 @@ class DynamoHandler(BaseResponse):
table = dynamodb_backend2.tables[name] table = dynamodb_backend2.tables[name]
except KeyError: except KeyError:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er) return self.error(er, 'Requested resource not found')
return dynamo_json_dump(table.describe(base_key='Table')) return dynamo_json_dump(table.describe(base_key='Table'))
def put_item(self): def put_item(self):
name = self.body['TableName'] name = self.body['TableName']
item = self.body['Item'] item = self.body['Item']
res = re.search('\"\"', json.dumps(item))
if res:
er = 'com.amazonaws.dynamodb.v20111205#ValidationException'
return (400,
{'server': 'amazon.com'},
dynamo_json_dump({'__type': er,
'message': ('One or more parameter values were '
'invalid: An AttributeValue may not '
'contain an empty string')}
))
overwrite = 'Expected' not in self.body overwrite = 'Expected' not in self.body
if not overwrite: if not overwrite:
expected = self.body['Expected'] expected = self.body['Expected']
@ -207,17 +190,20 @@ class DynamoHandler(BaseResponse):
try: try:
result = dynamodb_backend2.put_item( result = dynamodb_backend2.put_item(
name, item, expected, overwrite) name, item, expected, overwrite)
except Exception: except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er) return self.error(er, 'A condition specified in the operation could not be evaluated.')
if result: if result:
item_dict = result.to_json() item_dict = result.to_json()
item_dict['ConsumedCapacityUnits'] = 1 item_dict['ConsumedCapacity'] = {
'TableName': name,
'CapacityUnits': 1
}
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
else: else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er) return self.error(er, 'Requested resource not found')
def batch_write_item(self): def batch_write_item(self):
table_batches = self.body['RequestItems'] table_batches = self.body['RequestItems']
@ -254,15 +240,17 @@ class DynamoHandler(BaseResponse):
item = dynamodb_backend2.get_item(name, key) item = dynamodb_backend2.get_item(name, key)
except ValueError: except ValueError:
er = 'com.amazon.coral.validate#ValidationException' er = 'com.amazon.coral.validate#ValidationException'
return self.error(er, status=400) return self.error(er, 'Validation Exception')
if item: if item:
item_dict = item.describe_attrs(attributes=None) item_dict = item.describe_attrs(attributes=None)
item_dict['ConsumedCapacityUnits'] = 0.5 item_dict['ConsumedCapacity'] = {
'TableName': name,
'CapacityUnits': 0.5
}
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
else: else:
# Item not found # Item not found
er = '{}' return 200, self.response_headers, '{}'
return self.error(er, status=200)
def batch_get_item(self): def batch_get_item(self):
table_batches = self.body['RequestItems'] table_batches = self.body['RequestItems']
@ -296,11 +284,26 @@ class DynamoHandler(BaseResponse):
name = self.body['TableName'] name = self.body['TableName']
# {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}} # {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}}
key_condition_expression = self.body.get('KeyConditionExpression') key_condition_expression = self.body.get('KeyConditionExpression')
projection_expression = self.body.get('ProjectionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames')
if projection_expression and expression_attribute_names:
expressions = [x.strip() for x in projection_expression.split(',')]
for expression in expressions:
if expression in expression_attribute_names:
projection_expression = projection_expression.replace(expression, expression_attribute_names[expression])
filter_kwargs = {} filter_kwargs = {}
if key_condition_expression: if key_condition_expression:
value_alias_map = self.body['ExpressionAttributeValues'] value_alias_map = self.body['ExpressionAttributeValues']
table = dynamodb_backend2.get_table(name) table = dynamodb_backend2.get_table(name)
# If table does not exist
if table is None:
return self.error('com.amazonaws.dynamodb.v20120810#ResourceNotFoundException',
'Requested resource not found')
index_name = self.body.get('IndexName') index_name = self.body.get('IndexName')
if index_name: if index_name:
all_indexes = (table.global_indexes or []) + \ all_indexes = (table.global_indexes or []) + \
@ -316,24 +319,26 @@ class DynamoHandler(BaseResponse):
else: else:
index = table.schema index = table.schema
key_map = [column for _, column in sorted( reverse_attribute_lookup = dict((v, k) for k, v in
(k, v) for k, v in self.body['ExpressionAttributeNames'].items())] six.iteritems(self.body['ExpressionAttributeNames']))
if " AND " in key_condition_expression: if " AND " in key_condition_expression:
expressions = key_condition_expression.split(" AND ", 1) expressions = key_condition_expression.split(" AND ", 1)
index_hash_key = [ index_hash_key = [key for key in index if key['KeyType'] == 'HASH'][0]
key for key in index if key['KeyType'] == 'HASH'][0] hash_key_var = reverse_attribute_lookup.get(index_hash_key['AttributeName'],
hash_key_index_in_key_map = key_map.index( index_hash_key['AttributeName'])
index_hash_key['AttributeName']) hash_key_regex = r'(^|[\s(]){0}\b'.format(hash_key_var)
i, hash_key_expression = next((i, e) for i, e in enumerate(expressions)
if re.search(hash_key_regex, e))
hash_key_expression = hash_key_expression.strip('()')
expressions.pop(i)
hash_key_expression = expressions.pop( # TODO implement more than one range expression and OR operators
hash_key_index_in_key_map).strip('()')
# TODO implement more than one range expression and OR
# operators
range_key_expression = expressions[0].strip('()') range_key_expression = expressions[0].strip('()')
range_key_expression_components = range_key_expression.split() range_key_expression_components = range_key_expression.split()
range_comparison = range_key_expression_components[1] range_comparison = range_key_expression_components[1]
if 'AND' in range_key_expression: if 'AND' in range_key_expression:
range_comparison = 'BETWEEN' range_comparison = 'BETWEEN'
range_values = [ range_values = [
@ -367,7 +372,7 @@ class DynamoHandler(BaseResponse):
filter_kwargs[key] = value filter_kwargs[key] = value
if hash_key_name is None: if hash_key_name is None:
er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException" er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException"
return self.error(er) return self.error(er, 'Requested resource not found')
hash_key = key_conditions[hash_key_name][ hash_key = key_conditions[hash_key_name][
'AttributeValueList'][0] 'AttributeValueList'][0]
if len(key_conditions) == 1: if len(key_conditions) == 1:
@ -376,7 +381,7 @@ class DynamoHandler(BaseResponse):
else: else:
if range_key_name is None and not filter_kwargs: if range_key_name is None and not filter_kwargs:
er = "com.amazon.coral.validate#ValidationException" er = "com.amazon.coral.validate#ValidationException"
return self.error(er) return self.error(er, 'Validation Exception')
else: else:
range_condition = key_conditions.get(range_key_name) range_condition = key_conditions.get(range_key_name)
if range_condition: if range_condition:
@ -395,16 +400,20 @@ class DynamoHandler(BaseResponse):
scan_index_forward = self.body.get("ScanIndexForward") scan_index_forward = self.body.get("ScanIndexForward")
items, scanned_count, last_evaluated_key = dynamodb_backend2.query( items, scanned_count, last_evaluated_key = dynamodb_backend2.query(
name, hash_key, range_comparison, range_values, limit, name, hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, index_name=index_name, **filter_kwargs) exclusive_start_key, scan_index_forward, projection_expression, index_name=index_name, **filter_kwargs)
if items is None: if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er) return self.error(er, 'Requested resource not found')
result = { result = {
"Count": len(items), "Count": len(items),
"ConsumedCapacityUnits": 1, 'ConsumedCapacity': {
'TableName': name,
'CapacityUnits': 1,
},
"ScannedCount": scanned_count "ScannedCount": scanned_count
} }
if self.body.get('Select', '').upper() != 'COUNT': if self.body.get('Select', '').upper() != 'COUNT':
result["Items"] = [item.attrs for item in items] result["Items"] = [item.attrs for item in items]
@ -425,21 +434,40 @@ class DynamoHandler(BaseResponse):
comparison_values = scan_filter.get("AttributeValueList", []) comparison_values = scan_filter.get("AttributeValueList", [])
filters[attribute_name] = (comparison_operator, comparison_values) filters[attribute_name] = (comparison_operator, comparison_values)
filter_expression = self.body.get('FilterExpression')
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
exclusive_start_key = self.body.get('ExclusiveStartKey') exclusive_start_key = self.body.get('ExclusiveStartKey')
limit = self.body.get("Limit") limit = self.body.get("Limit")
items, scanned_count, last_evaluated_key = dynamodb_backend2.scan(name, filters, try:
limit, items, scanned_count, last_evaluated_key = dynamodb_backend2.scan(name, filters,
exclusive_start_key) limit,
exclusive_start_key,
filter_expression,
expression_attribute_names,
expression_attribute_values)
except ValueError as err:
er = 'com.amazonaws.dynamodb.v20111205#ValidationError'
return self.error(er, 'Bad Filter Expression: {0}'.format(err))
except Exception as err:
er = 'com.amazonaws.dynamodb.v20111205#InternalFailure'
return self.error(er, 'Internal error. {0}'.format(err))
# Items should be a list, at least an empty one. Is None if table does not exist.
# Should really check this at the beginning
if items is None: if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er) return self.error(er, 'Requested resource not found')
result = { result = {
"Count": len(items), "Count": len(items),
"Items": [item.attrs for item in items], "Items": [item.attrs for item in items],
"ConsumedCapacityUnits": 1, 'ConsumedCapacity': {
'TableName': name,
'CapacityUnits': 1,
},
"ScannedCount": scanned_count "ScannedCount": scanned_count
} }
if last_evaluated_key is not None: if last_evaluated_key is not None:
@ -453,7 +481,7 @@ class DynamoHandler(BaseResponse):
table = dynamodb_backend2.get_table(name) table = dynamodb_backend2.get_table(name)
if not table: if not table:
er = 'com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException' er = 'com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException'
return self.error(er) return self.error(er, 'A condition specified in the operation could not be evaluated.')
item = dynamodb_backend2.delete_item(name, keys) item = dynamodb_backend2.delete_item(name, keys)
if item and return_values == 'ALL_OLD': if item and return_values == 'ALL_OLD':
@ -474,17 +502,55 @@ class DynamoHandler(BaseResponse):
'ExpressionAttributeValues', {}) 'ExpressionAttributeValues', {})
existing_item = dynamodb_backend2.get_item(name, key) existing_item = dynamodb_backend2.get_item(name, key)
if 'Expected' in self.body:
expected = self.body['Expected']
else:
expected = None
# Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected:
condition_expression = self.body.get('ConditionExpression')
if condition_expression and 'OR' not in condition_expression:
cond_items = [c.strip()
for c in condition_expression.split('AND')]
if cond_items:
expected = {}
exists_re = re.compile('^attribute_exists\((.*)\)$')
not_exists_re = re.compile(
'^attribute_not_exists\((.*)\)$')
for cond in cond_items:
exists_m = exists_re.match(cond)
not_exists_m = not_exists_re.match(cond)
if exists_m:
expected[exists_m.group(1)] = {'Exists': True}
elif not_exists_m:
expected[not_exists_m.group(1)] = {'Exists': False}
# Support spaces between operators in an update expression # Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c` # E.g. `a = b + c` -> `a=b+c`
if update_expression: if update_expression:
update_expression = re.sub( update_expression = re.sub(
'\s*([=\+-])\s*', '\\1', update_expression) '\s*([=\+-])\s*', '\\1', update_expression)
item = dynamodb_backend2.update_item( try:
name, key, update_expression, attribute_updates, expression_attribute_names, expression_attribute_values) item = dynamodb_backend2.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names, expression_attribute_values,
expected)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er, 'A condition specified in the operation could not be evaluated.')
except TypeError:
er = 'com.amazonaws.dynamodb.v20111205#ValidationException'
return self.error(er, 'Validation Exception')
item_dict = item.to_json() item_dict = item.to_json()
item_dict['ConsumedCapacityUnits'] = 0.5 item_dict['ConsumedCapacity'] = {
'TableName': name,
'CapacityUnits': 0.5
}
if not existing_item: if not existing_item:
item_dict['Attributes'] = {} item_dict['Attributes'] = {}

View File

@ -2,8 +2,7 @@ from __future__ import unicode_literals
from .responses import DynamoHandler from .responses import DynamoHandler
url_bases = [ url_bases = [
"https?://dynamodb.(.+).amazonaws.com", "https?://dynamodb.(.+).amazonaws.com"
"https?://sts.amazonaws.com",
] ]
url_paths = { url_paths = {

View File

@ -384,3 +384,20 @@ class RulesPerSecurityGroupLimitExceededError(EC2ClientError):
"RulesPerSecurityGroupLimitExceeded", "RulesPerSecurityGroupLimitExceeded",
'The maximum number of rules per security group ' 'The maximum number of rules per security group '
'has been reached.') 'has been reached.')
class MotoNotImplementedError(NotImplementedError):
def __init__(self, blurb):
super(MotoNotImplementedError, self).__init__(
"{0} has not been implemented in Moto yet."
" Feel free to open an issue at"
" https://github.com/spulec/moto/issues".format(blurb))
class FilterNotImplementedError(MotoNotImplementedError):
def __init__(self, filter_name, method_name):
super(FilterNotImplementedError, self).__init__(
"The filter '{0}' for {1}".format(
filter_name, method_name))

View File

@ -2,9 +2,13 @@ from __future__ import unicode_literals
import copy import copy
import itertools import itertools
import json
import os
import re import re
import six import six
import boto.ec2
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from boto.ec2.instance import Instance as BotoInstance, Reservation from boto.ec2.instance import Instance as BotoInstance, Reservation
@ -61,6 +65,8 @@ from .exceptions import (
InvalidVpnConnectionIdError, InvalidVpnConnectionIdError,
InvalidCustomerGatewayIdError, InvalidCustomerGatewayIdError,
RulesPerSecurityGroupLimitExceededError, RulesPerSecurityGroupLimitExceededError,
MotoNotImplementedError,
FilterNotImplementedError
) )
from .utils import ( from .utils import (
EC2_RESOURCE_TO_PREFIX, EC2_RESOURCE_TO_PREFIX,
@ -104,8 +110,12 @@ from .utils import (
random_vpn_connection_id, random_vpn_connection_id,
random_customer_gateway_id, random_customer_gateway_id,
is_tag_filter, is_tag_filter,
tag_filter_matches,
) )
RESOURCES_DIR = os.path.join(os.path.dirname(__file__), 'resources')
INSTANCE_TYPES = json.load(open(os.path.join(RESOURCES_DIR, 'instance_types.json'), 'r'))
def utc_date_and_time(): def utc_date_and_time():
return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z') return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z')
@ -143,7 +153,7 @@ class TaggedEC2Resource(BaseModel):
for key, value in tag_map.items(): for key, value in tag_map.items():
self.ec2_backend.create_tags([self.id], {key: value}) self.ec2_backend.create_tags([self.id], {key: value})
def get_filter_value(self, filter_name): def get_filter_value(self, filter_name, method_name=None):
tags = self.get_tags() tags = self.get_tags()
if filter_name.startswith('tag:'): if filter_name.startswith('tag:'):
@ -153,12 +163,12 @@ class TaggedEC2Resource(BaseModel):
return tag['value'] return tag['value']
return '' return ''
elif filter_name == 'tag-key':
if filter_name == 'tag-key':
return [tag['key'] for tag in tags] return [tag['key'] for tag in tags]
elif filter_name == 'tag-value':
if filter_name == 'tag-value':
return [tag['value'] for tag in tags] return [tag['value'] for tag in tags]
else:
raise FilterNotImplementedError(filter_name, method_name)
class NetworkInterface(TaggedEC2Resource): class NetworkInterface(TaggedEC2Resource):
@ -260,17 +270,9 @@ class NetworkInterface(TaggedEC2Resource):
return [group.id for group in self._group_set] return [group.id for group in self._group_set]
elif filter_name == 'availability-zone': elif filter_name == 'availability-zone':
return self.subnet.availability_zone return self.subnet.availability_zone
else:
filter_value = super( return super(NetworkInterface, self).get_filter_value(
NetworkInterface, self).get_filter_value(filter_name) filter_name, 'DescribeNetworkInterfaces')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeNetworkInterfaces".format(
filter_name)
)
return filter_value
class NetworkInterfaceBackend(object): class NetworkInterfaceBackend(object):
@ -365,6 +367,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.user_data = user_data self.user_data = user_data
self.security_groups = security_groups self.security_groups = security_groups
self.instance_type = kwargs.get("instance_type", "m1.small") self.instance_type = kwargs.get("instance_type", "m1.small")
self.region_name = kwargs.get("region_name", "us-east-1")
placement = kwargs.get("placement", None) placement = kwargs.get("placement", None)
self.vpc_id = None self.vpc_id = None
self.subnet_id = kwargs.get("subnet_id") self.subnet_id = kwargs.get("subnet_id")
@ -373,6 +376,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.source_dest_check = "true" self.source_dest_check = "true"
self.launch_time = utc_date_and_time() self.launch_time = utc_date_and_time()
self.disable_api_termination = kwargs.get("disable_api_termination", False) self.disable_api_termination = kwargs.get("disable_api_termination", False)
self._spot_fleet_id = kwargs.get("spot_fleet_id", None)
associate_public_ip = kwargs.get("associate_public_ip", False) associate_public_ip = kwargs.get("associate_public_ip", False)
if in_ec2_classic: if in_ec2_classic:
# If we are in EC2-Classic, autoassign a public IP # If we are in EC2-Classic, autoassign a public IP
@ -432,7 +436,11 @@ class Instance(TaggedEC2Resource, BotoInstance):
@property @property
def private_dns(self): def private_dns(self):
return "ip-{0}.ec2.internal".format(self.private_ip) formatted_ip = self.private_ip.replace('.', '-')
if self.region_name == "us-east-1":
return "ip-{0}.ec2.internal".format(formatted_ip)
else:
return "ip-{0}.{1}.compute.internal".format(formatted_ip, self.region_name)
@property @property
def public_ip(self): def public_ip(self):
@ -441,7 +449,11 @@ class Instance(TaggedEC2Resource, BotoInstance):
@property @property
def public_dns(self): def public_dns(self):
if self.public_ip: if self.public_ip:
return "ec2-{0}.compute-1.amazonaws.com".format(self.public_ip) formatted_ip = self.public_ip.replace('.', '-')
if self.region_name == "us-east-1":
return "ec2-{0}.compute-1.amazonaws.com".format(formatted_ip)
else:
return "ec2-{0}.{1}.compute.amazonaws.com".format(formatted_ip, self.region_name)
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -502,6 +514,14 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.teardown_defaults() self.teardown_defaults()
if self._spot_fleet_id:
spot_fleet = self.ec2_backend.get_spot_fleet_request(self._spot_fleet_id)
for spec in spot_fleet.launch_specs:
if spec.instance_type == self.instance_type and spec.subnet_id == self.subnet_id:
break
spot_fleet.fulfilled_capacity -= spec.weighted_capacity
spot_fleet.spot_requests = [req for req in spot_fleet.spot_requests if req.instance != self]
self._state.name = "terminated" self._state.name = "terminated"
self._state.code = 48 self._state.code = 48
@ -580,10 +600,6 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.attach_eni(use_nic, device_index) self.attach_eni(use_nic, device_index)
def set_ip(self, ip_address):
# Should we be creating a new ENI?
self.nics[0].public_ip = ip_address
def attach_eni(self, eni, device_index): def attach_eni(self, eni, device_index):
device_index = int(device_index) device_index = int(device_index)
self.nics[device_index] = eni self.nics[device_index] = eni
@ -785,16 +801,31 @@ class InstanceBackend(object):
return reservations return reservations
class KeyPair(object):
def __init__(self, name, fingerprint, material):
self.name = name
self.fingerprint = fingerprint
self.material = material
def get_filter_value(self, filter_name):
if filter_name == 'key-name':
return self.name
elif filter_name == 'fingerprint':
return self.fingerprint
else:
raise FilterNotImplementedError(filter_name, 'DescribeKeyPairs')
class KeyPairBackend(object): class KeyPairBackend(object):
def __init__(self): def __init__(self):
self.keypairs = defaultdict(dict) self.keypairs = {}
super(KeyPairBackend, self).__init__() super(KeyPairBackend, self).__init__()
def create_key_pair(self, name): def create_key_pair(self, name):
if name in self.keypairs: if name in self.keypairs:
raise InvalidKeyPairDuplicateError(name) raise InvalidKeyPairDuplicateError(name)
self.keypairs[name] = keypair = random_key_pair() keypair = KeyPair(name, **random_key_pair())
keypair['name'] = name self.keypairs[name] = keypair
return keypair return keypair
def delete_key_pair(self, name): def delete_key_pair(self, name):
@ -802,24 +833,27 @@ class KeyPairBackend(object):
self.keypairs.pop(name) self.keypairs.pop(name)
return True return True
def describe_key_pairs(self, filter_names=None): def describe_key_pairs(self, key_names=None, filters=None):
results = [] results = []
for name, keypair in self.keypairs.items(): if key_names:
if not filter_names or name in filter_names: results = [keypair for keypair in self.keypairs.values()
keypair['name'] = name if keypair.name in key_names]
results.append(keypair) if len(key_names) > len(results):
unknown_keys = set(key_names) - set(results)
raise InvalidKeyPairNameError(unknown_keys)
else:
results = self.keypairs.values()
# TODO: Trim error message down to specific invalid name. if filters:
if filter_names and len(filter_names) > len(results): return generic_filter(filters, results)
raise InvalidKeyPairNameError(filter_names) else:
return results
return results
def import_key_pair(self, key_name, public_key_material): def import_key_pair(self, key_name, public_key_material):
if key_name in self.keypairs: if key_name in self.keypairs:
raise InvalidKeyPairDuplicateError(key_name) raise InvalidKeyPairDuplicateError(key_name)
self.keypairs[key_name] = keypair = random_key_pair() keypair = KeyPair(key_name, **random_key_pair())
keypair['name'] = key_name self.keypairs[key_name] = keypair
return keypair return keypair
@ -1017,14 +1051,9 @@ class Ami(TaggedEC2Resource):
return self.state return self.state
elif filter_name == 'name': elif filter_name == 'name':
return self.name return self.name
else:
filter_value = super(Ami, self).get_filter_value(filter_name) return super(Ami, self).get_filter_value(
filter_name, 'DescribeImages')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeImages".format(filter_name))
return filter_value
class AmiBackend(object): class AmiBackend(object):
@ -1144,24 +1173,7 @@ class Zone(object):
class RegionsAndZonesBackend(object): class RegionsAndZonesBackend(object):
regions = [ regions = [Region(ri.name, ri.endpoint) for ri in boto.ec2.regions()]
Region("ap-northeast-1", "ec2.ap-northeast-1.amazonaws.com"),
Region("ap-northeast-2", "ec2.ap-northeast-2.amazonaws.com"),
Region("ap-south-1", "ec2.ap-south-1.amazonaws.com"),
Region("ap-southeast-1", "ec2.ap-southeast-1.amazonaws.com"),
Region("ap-southeast-2", "ec2.ap-southeast-2.amazonaws.com"),
Region("ca-central-1", "ec2.ca-central-1.amazonaws.com.cn"),
Region("cn-north-1", "ec2.cn-north-1.amazonaws.com.cn"),
Region("eu-central-1", "ec2.eu-central-1.amazonaws.com"),
Region("eu-west-1", "ec2.eu-west-1.amazonaws.com"),
Region("eu-west-2", "ec2.eu-west-2.amazonaws.com"),
Region("sa-east-1", "ec2.sa-east-1.amazonaws.com"),
Region("us-east-1", "ec2.us-east-1.amazonaws.com"),
Region("us-east-2", "ec2.us-east-2.amazonaws.com"),
Region("us-gov-west-1", "ec2.us-gov-west-1.amazonaws.com"),
Region("us-west-1", "ec2.us-west-1.amazonaws.com"),
Region("us-west-2", "ec2.us-west-2.amazonaws.com"),
]
zones = dict( zones = dict(
(region, [Zone(region + c, region) for c in 'abc']) (region, [Zone(region + c, region) for c in 'abc'])
@ -1299,7 +1311,7 @@ class SecurityGroup(TaggedEC2Resource):
elif is_tag_filter(key): elif is_tag_filter(key):
tag_value = self.get_filter_value(key) tag_value = self.get_filter_value(key)
if isinstance(filter_value, list): if isinstance(filter_value, list):
return any(v in tag_value for v in filter_value) return tag_filter_matches(self, key, filter_value)
return tag_value in filter_value return tag_value in filter_value
else: else:
attr_name = to_attr(key) attr_name = to_attr(key)
@ -1364,22 +1376,25 @@ class SecurityGroupBackend(object):
return group return group
def describe_security_groups(self, group_ids=None, groupnames=None, filters=None): def describe_security_groups(self, group_ids=None, groupnames=None, filters=None):
all_groups = itertools.chain(*[x.values() matches = itertools.chain(*[x.values()
for x in self.groups.values()]) for x in self.groups.values()])
groups = [] if group_ids:
matches = [grp for grp in matches
if grp.id in group_ids]
if len(group_ids) > len(matches):
unknown_ids = set(group_ids) - set(matches)
raise InvalidSecurityGroupNotFoundError(unknown_ids)
if groupnames:
matches = [grp for grp in matches
if grp.name in groupnames]
if len(groupnames) > len(matches):
unknown_names = set(groupnames) - set(matches)
raise InvalidSecurityGroupNotFoundError(unknown_names)
if filters:
matches = [grp for grp in matches
if grp.matches_filters(filters)]
if group_ids or groupnames or filters: return matches
for group in all_groups:
if ((group_ids and group.id not in group_ids) or
(groupnames and group.name not in groupnames)):
continue
if filters and not group.matches_filters(filters):
continue
groups.append(group)
else:
groups = all_groups
return groups
def _delete_security_group(self, vpc_id, group_id): def _delete_security_group(self, vpc_id, group_id):
if self.groups[vpc_id][group_id].enis: if self.groups[vpc_id][group_id].enis:
@ -1698,43 +1713,31 @@ class Volume(TaggedEC2Resource):
return 'available' return 'available'
def get_filter_value(self, filter_name): def get_filter_value(self, filter_name):
if filter_name.startswith('attachment') and not self.attachment: if filter_name.startswith('attachment') and not self.attachment:
return None return None
if filter_name == 'attachment.attach-time': elif filter_name == 'attachment.attach-time':
return self.attachment.attach_time return self.attachment.attach_time
if filter_name == 'attachment.device': elif filter_name == 'attachment.device':
return self.attachment.device return self.attachment.device
if filter_name == 'attachment.instance-id': elif filter_name == 'attachment.instance-id':
return self.attachment.instance.id return self.attachment.instance.id
if filter_name == 'attachment.status': elif filter_name == 'attachment.status':
return self.attachment.status return self.attachment.status
elif filter_name == 'create-time':
if filter_name == 'create-time':
return self.create_time return self.create_time
elif filter_name == 'size':
if filter_name == 'size':
return self.size return self.size
elif filter_name == 'snapshot-id':
if filter_name == 'snapshot-id':
return self.snapshot_id return self.snapshot_id
elif filter_name == 'status':
if filter_name == 'status':
return self.status return self.status
elif filter_name == 'volume-id':
if filter_name == 'volume-id':
return self.id return self.id
elif filter_name == 'encrypted':
if filter_name == 'encrypted':
return str(self.encrypted).lower() return str(self.encrypted).lower()
else:
filter_value = super(Volume, self).get_filter_value(filter_name) return super(Volume, self).get_filter_value(
filter_name, 'DescribeVolumes')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeVolumes".format(filter_name))
return filter_value
class Snapshot(TaggedEC2Resource): class Snapshot(TaggedEC2Resource):
@ -1749,35 +1752,23 @@ class Snapshot(TaggedEC2Resource):
self.encrypted = encrypted self.encrypted = encrypted
def get_filter_value(self, filter_name): def get_filter_value(self, filter_name):
if filter_name == 'description': if filter_name == 'description':
return self.description return self.description
elif filter_name == 'snapshot-id':
if filter_name == 'snapshot-id':
return self.id return self.id
elif filter_name == 'start-time':
if filter_name == 'start-time':
return self.start_time return self.start_time
elif filter_name == 'volume-id':
if filter_name == 'volume-id':
return self.volume.id return self.volume.id
elif filter_name == 'volume-size':
if filter_name == 'volume-size':
return self.volume.size return self.volume.size
elif filter_name == 'encrypted':
if filter_name == 'encrypted':
return str(self.encrypted).lower() return str(self.encrypted).lower()
elif filter_name == 'status':
if filter_name == 'status':
return self.status return self.status
else:
filter_value = super(Snapshot, self).get_filter_value(filter_name) return super(Snapshot, self).get_filter_value(
filter_name, 'DescribeSnapshots')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeSnapshots".format(filter_name))
return filter_value
class EBSBackend(object): class EBSBackend(object):
@ -1800,11 +1791,17 @@ class EBSBackend(object):
self.volumes[volume_id] = volume self.volumes[volume_id] = volume
return volume return volume
def describe_volumes(self, filters=None): def describe_volumes(self, volume_ids=None, filters=None):
matches = self.volumes.values()
if volume_ids:
matches = [vol for vol in matches
if vol.id in volume_ids]
if len(volume_ids) > len(matches):
unknown_ids = set(volume_ids) - set(matches)
raise InvalidVolumeIdError(unknown_ids)
if filters: if filters:
volumes = self.volumes.values() matches = generic_filter(filters, matches)
return generic_filter(filters, volumes) return matches
return self.volumes.values()
def get_volume(self, volume_id): def get_volume(self, volume_id):
volume = self.volumes.get(volume_id, None) volume = self.volumes.get(volume_id, None)
@ -1856,11 +1853,17 @@ class EBSBackend(object):
self.snapshots[snapshot_id] = snapshot self.snapshots[snapshot_id] = snapshot
return snapshot return snapshot
def describe_snapshots(self, filters=None): def describe_snapshots(self, snapshot_ids=None, filters=None):
matches = self.snapshots.values()
if snapshot_ids:
matches = [snap for snap in matches
if snap.id in snapshot_ids]
if len(snapshot_ids) > len(matches):
unknown_ids = set(snapshot_ids) - set(matches)
raise InvalidSnapshotIdError(unknown_ids)
if filters: if filters:
snapshots = self.snapshots.values() matches = generic_filter(filters, matches)
return generic_filter(filters, snapshots) return matches
return self.snapshots.values()
def get_snapshot(self, snapshot_id): def get_snapshot(self, snapshot_id):
snapshot = self.snapshots.get(snapshot_id, None) snapshot = self.snapshots.get(snapshot_id, None)
@ -1943,16 +1946,10 @@ class VPC(TaggedEC2Resource):
elif filter_name in ('dhcp-options-id', 'dhcpOptionsId'): elif filter_name in ('dhcp-options-id', 'dhcpOptionsId'):
if not self.dhcp_options: if not self.dhcp_options:
return None return None
return self.dhcp_options.id return self.dhcp_options.id
else:
filter_value = super(VPC, self).get_filter_value(filter_name) return super(VPC, self).get_filter_value(
filter_name, 'DescribeVpcs')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeVPCs".format(filter_name))
return filter_value
class VPCBackend(object): class VPCBackend(object):
@ -1985,12 +1982,16 @@ class VPCBackend(object):
return self.vpcs.get(vpc_id) return self.vpcs.get(vpc_id)
def get_all_vpcs(self, vpc_ids=None, filters=None): def get_all_vpcs(self, vpc_ids=None, filters=None):
matches = self.vpcs.values()
if vpc_ids: if vpc_ids:
vpcs = [vpc for vpc in self.vpcs.values() if vpc.id in vpc_ids] matches = [vpc for vpc in matches
else: if vpc.id in vpc_ids]
vpcs = self.vpcs.values() if len(vpc_ids) > len(matches):
unknown_ids = set(vpc_ids) - set(matches)
return generic_filter(filters, vpcs) raise InvalidVPCIdError(unknown_ids)
if filters:
matches = generic_filter(filters, matches)
return matches
def delete_vpc(self, vpc_id): def delete_vpc(self, vpc_id):
# Delete route table if only main route table remains. # Delete route table if only main route table remains.
@ -2186,14 +2187,9 @@ class Subnet(TaggedEC2Resource):
return self.availability_zone return self.availability_zone
elif filter_name in ('defaultForAz', 'default-for-az'): elif filter_name in ('defaultForAz', 'default-for-az'):
return self.default_for_az return self.default_for_az
else:
filter_value = super(Subnet, self).get_filter_value(filter_name) return super(Subnet, self).get_filter_value(
filter_name, 'DescribeSubnets')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeSubnets".format(filter_name))
return filter_value
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -2232,16 +2228,19 @@ class SubnetBackend(object):
return subnet return subnet
def get_all_subnets(self, subnet_ids=None, filters=None): def get_all_subnets(self, subnet_ids=None, filters=None):
subnets = [] # Extract a list of all subnets
matches = itertools.chain(*[x.values()
for x in self.subnets.values()])
if subnet_ids: if subnet_ids:
for subnet_id in subnet_ids: matches = [sn for sn in matches
for items in self.subnets.values(): if sn.id in subnet_ids]
if subnet_id in items: if len(subnet_ids) > len(matches):
subnets.append(items[subnet_id]) unknown_ids = set(subnet_ids) - set(matches)
else: raise InvalidSubnetIdError(unknown_ids)
for items in self.subnets.values(): if filters:
subnets.extend(items.values()) matches = generic_filter(filters, matches)
return generic_filter(filters, subnets)
return matches
def delete_subnet(self, subnet_id): def delete_subnet(self, subnet_id):
for subnets in self.subnets.values(): for subnets in self.subnets.values():
@ -2331,14 +2330,9 @@ class RouteTable(TaggedEC2Resource):
return self.associations.keys() return self.associations.keys()
elif filter_name == "association.subnet-id": elif filter_name == "association.subnet-id":
return self.associations.values() return self.associations.values()
else:
filter_value = super(RouteTable, self).get_filter_value(filter_name) return super(RouteTable, self).get_filter_value(
filter_name, 'DescribeRouteTables')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeRouteTables".format(filter_name))
return filter_value
class RouteTableBackend(object): class RouteTableBackend(object):
@ -2644,7 +2638,7 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
def __init__(self, ec2_backend, spot_request_id, price, image_id, type, def __init__(self, ec2_backend, spot_request_id, price, image_id, type,
valid_from, valid_until, launch_group, availability_zone_group, valid_from, valid_until, launch_group, availability_zone_group,
key_name, security_groups, user_data, instance_type, placement, key_name, security_groups, user_data, instance_type, placement,
kernel_id, ramdisk_id, monitoring_enabled, subnet_id, kernel_id, ramdisk_id, monitoring_enabled, subnet_id, spot_fleet_id,
**kwargs): **kwargs):
super(SpotInstanceRequest, self).__init__(**kwargs) super(SpotInstanceRequest, self).__init__(**kwargs)
ls = LaunchSpecification() ls = LaunchSpecification()
@ -2667,6 +2661,7 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
ls.placement = placement ls.placement = placement
ls.monitored = monitoring_enabled ls.monitored = monitoring_enabled
ls.subnet_id = subnet_id ls.subnet_id = subnet_id
self.spot_fleet_id = spot_fleet_id
if security_groups: if security_groups:
for group_name in security_groups: for group_name in security_groups:
@ -2685,16 +2680,11 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
def get_filter_value(self, filter_name): def get_filter_value(self, filter_name):
if filter_name == 'state': if filter_name == 'state':
return self.state return self.state
if filter_name == 'spot-instance-request-id': elif filter_name == 'spot-instance-request-id':
return self.id return self.id
filter_value = super(SpotInstanceRequest, else:
self).get_filter_value(filter_name) return super(SpotInstanceRequest, self).get_filter_value(
filter_name, 'DescribeSpotInstanceRequests')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeSpotInstanceRequests".format(filter_name))
return filter_value
def launch_instance(self): def launch_instance(self):
reservation = self.ec2_backend.add_instances( reservation = self.ec2_backend.add_instances(
@ -2704,6 +2694,7 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
key_name=self.launch_specification.key_name, key_name=self.launch_specification.key_name,
security_group_names=[], security_group_names=[],
security_group_ids=self.launch_specification.groups, security_group_ids=self.launch_specification.groups,
spot_fleet_id=self.spot_fleet_id,
) )
instance = reservation.instances[0] instance = reservation.instances[0]
return instance return instance
@ -2719,7 +2710,7 @@ class SpotRequestBackend(object):
valid_until, launch_group, availability_zone_group, valid_until, launch_group, availability_zone_group,
key_name, security_groups, user_data, key_name, security_groups, user_data,
instance_type, placement, kernel_id, ramdisk_id, instance_type, placement, kernel_id, ramdisk_id,
monitoring_enabled, subnet_id): monitoring_enabled, subnet_id, spot_fleet_id=None):
requests = [] requests = []
for _ in range(count): for _ in range(count):
spot_request_id = random_spot_request_id() spot_request_id = random_spot_request_id()
@ -2727,7 +2718,7 @@ class SpotRequestBackend(object):
spot_request_id, price, image_id, type, valid_from, valid_until, spot_request_id, price, image_id, type, valid_from, valid_until,
launch_group, availability_zone_group, key_name, security_groups, launch_group, availability_zone_group, key_name, security_groups,
user_data, instance_type, placement, kernel_id, ramdisk_id, user_data, instance_type, placement, kernel_id, ramdisk_id,
monitoring_enabled, subnet_id) monitoring_enabled, subnet_id, spot_fleet_id)
self.spot_instance_requests[spot_request_id] = request self.spot_instance_requests[spot_request_id] = request
requests.append(request) requests.append(request)
return requests return requests
@ -2773,7 +2764,7 @@ class SpotFleetRequest(TaggedEC2Resource):
self.iam_fleet_role = iam_fleet_role self.iam_fleet_role = iam_fleet_role
self.allocation_strategy = allocation_strategy self.allocation_strategy = allocation_strategy
self.state = "active" self.state = "active"
self.fulfilled_capacity = self.target_capacity self.fulfilled_capacity = 0.0
self.launch_specs = [] self.launch_specs = []
for spec in launch_specs: for spec in launch_specs:
@ -2794,7 +2785,7 @@ class SpotFleetRequest(TaggedEC2Resource):
) )
self.spot_requests = [] self.spot_requests = []
self.create_spot_requests() self.create_spot_requests(self.target_capacity)
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -2824,31 +2815,32 @@ class SpotFleetRequest(TaggedEC2Resource):
return spot_fleet_request return spot_fleet_request
def get_launch_spec_counts(self): def get_launch_spec_counts(self, weight_to_add):
weight_map = defaultdict(int) weight_map = defaultdict(int)
weight_so_far = 0
if self.allocation_strategy == 'diversified': if self.allocation_strategy == 'diversified':
weight_so_far = 0
launch_spec_index = 0 launch_spec_index = 0
while True: while True:
launch_spec = self.launch_specs[ launch_spec = self.launch_specs[
launch_spec_index % len(self.launch_specs)] launch_spec_index % len(self.launch_specs)]
weight_map[launch_spec] += 1 weight_map[launch_spec] += 1
weight_so_far += launch_spec.weighted_capacity weight_so_far += launch_spec.weighted_capacity
if weight_so_far >= self.target_capacity: if weight_so_far >= weight_to_add:
break break
launch_spec_index += 1 launch_spec_index += 1
else: # lowestPrice else: # lowestPrice
cheapest_spec = sorted( cheapest_spec = sorted(
self.launch_specs, key=lambda spec: float(spec.spot_price))[0] self.launch_specs, key=lambda spec: float(spec.spot_price))[0]
extra = 1 if self.target_capacity % cheapest_spec.weighted_capacity else 0 weight_so_far = weight_to_add + (weight_to_add % cheapest_spec.weighted_capacity)
weight_map[cheapest_spec] = int( weight_map[cheapest_spec] = int(
self.target_capacity // cheapest_spec.weighted_capacity) + extra weight_so_far // cheapest_spec.weighted_capacity)
return weight_map.items() return weight_map, weight_so_far
def create_spot_requests(self): def create_spot_requests(self, weight_to_add):
for launch_spec, count in self.get_launch_spec_counts(): weight_map, added_weight = self.get_launch_spec_counts(weight_to_add)
for launch_spec, count in weight_map.items():
requests = self.ec2_backend.request_spot_instances( requests = self.ec2_backend.request_spot_instances(
price=launch_spec.spot_price, price=launch_spec.spot_price,
image_id=launch_spec.image_id, image_id=launch_spec.image_id,
@ -2867,12 +2859,28 @@ class SpotFleetRequest(TaggedEC2Resource):
ramdisk_id=None, ramdisk_id=None,
monitoring_enabled=launch_spec.monitoring, monitoring_enabled=launch_spec.monitoring,
subnet_id=launch_spec.subnet_id, subnet_id=launch_spec.subnet_id,
spot_fleet_id=self.id,
) )
self.spot_requests.extend(requests) self.spot_requests.extend(requests)
self.fulfilled_capacity += added_weight
return self.spot_requests return self.spot_requests
def terminate_instances(self): def terminate_instances(self):
pass instance_ids = []
new_fulfilled_capacity = self.fulfilled_capacity
for req in self.spot_requests:
instance = req.instance
for spec in self.launch_specs:
if spec.instance_type == instance.instance_type and spec.subnet_id == instance.subnet_id:
break
if new_fulfilled_capacity - spec.weighted_capacity < self.target_capacity:
continue
new_fulfilled_capacity -= spec.weighted_capacity
instance_ids.append(instance.id)
self.spot_requests = [req for req in self.spot_requests if req.instance.id not in instance_ids]
self.ec2_backend.terminate_instances(instance_ids)
class SpotFleetBackend(object): class SpotFleetBackend(object):
@ -2908,12 +2916,26 @@ class SpotFleetBackend(object):
def cancel_spot_fleet_requests(self, spot_fleet_request_ids, terminate_instances): def cancel_spot_fleet_requests(self, spot_fleet_request_ids, terminate_instances):
spot_requests = [] spot_requests = []
for spot_fleet_request_id in spot_fleet_request_ids: for spot_fleet_request_id in spot_fleet_request_ids:
spot_fleet = self.spot_fleet_requests.pop(spot_fleet_request_id) spot_fleet = self.spot_fleet_requests[spot_fleet_request_id]
if terminate_instances: if terminate_instances:
spot_fleet.target_capacity = 0
spot_fleet.terminate_instances() spot_fleet.terminate_instances()
spot_requests.append(spot_fleet) spot_requests.append(spot_fleet)
del self.spot_fleet_requests[spot_fleet_request_id]
return spot_requests return spot_requests
def modify_spot_fleet_request(self, spot_fleet_request_id, target_capacity, terminate_instances):
if target_capacity < 0:
raise ValueError('Cannot reduce spot fleet capacity below 0')
spot_fleet_request = self.spot_fleet_requests[spot_fleet_request_id]
delta = target_capacity - spot_fleet_request.fulfilled_capacity
spot_fleet_request.target_capacity = target_capacity
if delta > 0:
spot_fleet_request.create_spot_requests(delta)
elif delta < 0 and terminate_instances == 'Default':
spot_fleet_request.terminate_instances()
return True
class ElasticAddress(object): class ElasticAddress(object):
def __init__(self, domain): def __init__(self, domain):
@ -2954,6 +2976,25 @@ class ElasticAddress(object):
return self.allocation_id return self.allocation_id
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
def get_filter_value(self, filter_name):
if filter_name == 'allocation-id':
return self.allocation_id
elif filter_name == 'association-id':
return self.association_id
elif filter_name == 'domain':
return self.domain
elif filter_name == 'instance-id' and self.instance:
return self.instance.id
elif filter_name == 'network-interface-id' and self.eni:
return self.eni.id
elif filter_name == 'private-ip-address' and self.eni:
return self.eni.private_ip_address
elif filter_name == 'public-ip':
return self.public_ip
else:
# TODO: implement network-interface-owner-id
raise FilterNotImplementedError(filter_name, 'DescribeAddresses')
class ElasticAddressBackend(object): class ElasticAddressBackend(object):
def __init__(self): def __init__(self):
@ -3014,19 +3055,36 @@ class ElasticAddressBackend(object):
if new_instance_association or new_eni_association or reassociate: if new_instance_association or new_eni_association or reassociate:
eip.instance = instance eip.instance = instance
eip.eni = eni eip.eni = eni
if not eip.eni and instance:
# default to primary network interface
eip.eni = instance.nics[0]
if eip.eni: if eip.eni:
eip.eni.public_ip = eip.public_ip eip.eni.public_ip = eip.public_ip
if eip.domain == "vpc": if eip.domain == "vpc":
eip.association_id = random_eip_association_id() eip.association_id = random_eip_association_id()
if instance:
instance.set_ip(eip.public_ip)
return eip return eip
raise ResourceAlreadyAssociatedError(eip.public_ip) raise ResourceAlreadyAssociatedError(eip.public_ip)
def describe_addresses(self): def describe_addresses(self, allocation_ids=None, public_ips=None, filters=None):
return self.addresses matches = self.addresses
if allocation_ids:
matches = [addr for addr in matches
if addr.allocation_id in allocation_ids]
if len(allocation_ids) > len(matches):
unknown_ids = set(allocation_ids) - set(matches)
raise InvalidAllocationIdError(unknown_ids)
if public_ips:
matches = [addr for addr in matches
if addr.public_ip in public_ips]
if len(public_ips) > len(matches):
unknown_ips = set(allocation_ids) - set(matches)
raise InvalidAddressError(unknown_ips)
if filters:
matches = generic_filter(filters, matches)
return matches
def disassociate_address(self, address=None, association_id=None): def disassociate_address(self, address=None, association_id=None):
eips = [] eips = []
@ -3037,10 +3095,9 @@ class ElasticAddressBackend(object):
eip = eips[0] eip = eips[0]
if eip.eni: if eip.eni:
eip.eni.public_ip = None
if eip.eni.instance and eip.eni.instance._state.name == "running": if eip.eni.instance and eip.eni.instance._state.name == "running":
eip.eni.check_auto_public_ip() eip.eni.check_auto_public_ip()
else:
eip.eni.public_ip = None
eip.eni = None eip.eni = None
eip.instance = None eip.instance = None
@ -3096,15 +3153,9 @@ class DHCPOptionsSet(TaggedEC2Resource):
elif filter_name == 'value': elif filter_name == 'value':
values = [item for item in list(self._options.values()) if item] values = [item for item in list(self._options.values()) if item]
return itertools.chain(*values) return itertools.chain(*values)
else:
filter_value = super( return super(DHCPOptionsSet, self).get_filter_value(
DHCPOptionsSet, self).get_filter_value(filter_name) filter_name, 'DescribeDhcpOptions')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeDhcpOptions".format(filter_name))
return filter_value
@property @property
def options(self): def options(self):
@ -3191,6 +3242,10 @@ class VPNConnection(TaggedEC2Resource):
self.options = None self.options = None
self.static_routes = None self.static_routes = None
def get_filter_value(self, filter_name):
return super(VPNConnection, self).get_filter_value(
filter_name, 'DescribeVpnConnections')
class VPNConnectionBackend(object): class VPNConnectionBackend(object):
def __init__(self): def __init__(self):
@ -3370,14 +3425,9 @@ class NetworkAcl(TaggedEC2Resource):
return self.id return self.id
elif filter_name == "association.subnet-id": elif filter_name == "association.subnet-id":
return [assoc.subnet_id for assoc in self.associations.values()] return [assoc.subnet_id for assoc in self.associations.values()]
else:
filter_value = super(NetworkAcl, self).get_filter_value(filter_name) return super(NetworkAcl, self).get_filter_value(
filter_name, 'DescribeNetworkAcls')
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeNetworkAcls".format(filter_name))
return filter_value
class NetworkAclEntry(TaggedEC2Resource): class NetworkAclEntry(TaggedEC2Resource):
@ -3406,6 +3456,10 @@ class VpnGateway(TaggedEC2Resource):
self.attachments = {} self.attachments = {}
super(VpnGateway, self).__init__() super(VpnGateway, self).__init__()
def get_filter_value(self, filter_name):
return super(VpnGateway, self).get_filter_value(
filter_name, 'DescribeVpnGateways')
class VpnGatewayAttachment(object): class VpnGatewayAttachment(object):
def __init__(self, vpc_id, state): def __init__(self, vpc_id, state):
@ -3467,6 +3521,10 @@ class CustomerGateway(TaggedEC2Resource):
self.attachments = {} self.attachments = {}
super(CustomerGateway, self).__init__() super(CustomerGateway, self).__init__()
def get_filter_value(self, filter_name):
return super(CustomerGateway, self).get_filter_value(
filter_name, 'DescribeCustomerGateways')
class CustomerGatewayBackend(object): class CustomerGatewayBackend(object):
def __init__(self): def __init__(self):
@ -3573,8 +3631,8 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, AmiBackend,
DHCPOptionsSetBackend, NetworkAclBackend, VpnGatewayBackend, DHCPOptionsSetBackend, NetworkAclBackend, VpnGatewayBackend,
CustomerGatewayBackend, NatGatewayBackend): CustomerGatewayBackend, NatGatewayBackend):
def __init__(self, region_name): def __init__(self, region_name):
super(EC2Backend, self).__init__()
self.region_name = region_name self.region_name = region_name
super(EC2Backend, self).__init__()
# Default VPC exists by default, which is the current behavior # Default VPC exists by default, which is the current behavior
# of EC2-VPC. See for detail: # of EC2-VPC. See for detail:
@ -3610,10 +3668,7 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, AmiBackend,
raise EC2ClientError(code, message) raise EC2ClientError(code, message)
def raise_not_implemented_error(self, blurb): def raise_not_implemented_error(self, blurb):
msg = "{0} has not been implemented in Moto yet." \ raise MotoNotImplementedError(blurb)
" Feel free to open an issue at" \
" https://github.com/spulec/moto/issues".format(blurb)
raise NotImplementedError(msg)
def do_resources_exist(self, resource_ids): def do_resources_exist(self, resource_ids):
for resource_id in resource_ids: for resource_id in resource_ids:
@ -3660,6 +3715,5 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, AmiBackend,
return True return True
ec2_backends = {} ec2_backends = {region.name: EC2Backend(region.name)
for region in RegionsAndZonesBackend.regions: for region in RegionsAndZonesBackend.regions}
ec2_backends[region.name] = EC2Backend(region.name)

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .account_attributes import AccountAttributes
from .amazon_dev_pay import AmazonDevPay from .amazon_dev_pay import AmazonDevPay
from .amis import AmisResponse from .amis import AmisResponse
from .availability_zones_and_regions import AvailabilityZonesAndRegions from .availability_zones_and_regions import AvailabilityZonesAndRegions
@ -34,6 +35,7 @@ from .nat_gateways import NatGateways
class EC2Response( class EC2Response(
AccountAttributes,
AmazonDevPay, AmazonDevPay,
AmisResponse, AmisResponse,
AvailabilityZonesAndRegions, AvailabilityZonesAndRegions,

View File

@ -0,0 +1,69 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
class AccountAttributes(BaseResponse):
def describe_account_attributes(self):
template = self.response_template(DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT)
return template.render()
DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = u"""
<DescribeAccountAttributesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<accountAttributeSet>
<item>
<attributeName>vpc-max-security-groups-per-interface</attributeName>
<attributeValueSet>
<item>
<attributeValue>5</attributeValue>
</item>
</attributeValueSet>
</item>
<item>
<attributeName>max-instances</attributeName>
<attributeValueSet>
<item>
<attributeValue>20</attributeValue>
</item>
</attributeValueSet>
</item>
<item>
<attributeName>supported-platforms</attributeName>
<attributeValueSet>
<item>
<attributeValue>EC2</attributeValue>
</item>
<item>
<attributeValue>VPC</attributeValue>
</item>
</attributeValueSet>
</item>
<item>
<attributeName>default-vpc</attributeName>
<attributeValueSet>
<item>
<attributeValue>none</attributeValue>
</item>
</attributeValueSet>
</item>
<item>
<attributeName>max-elastic-ips</attributeName>
<attributeValueSet>
<item>
<attributeValue>5</attributeValue>
</item>
</attributeValueSet>
</item>
<item>
<attributeName>vpc-max-elastic-ips</attributeName>
<attributeValueSet>
<item>
<attributeValue>5</attributeValue>
</item>
</attributeValueSet>
</item>
</accountAttributeSet>
</DescribeAccountAttributesResponse>
"""

View File

@ -1,19 +1,14 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import instance_ids_from_querystring, image_ids_from_querystring, \ from moto.ec2.utils import filters_from_querystring
filters_from_querystring, sequence_from_querystring, executable_users_from_querystring
class AmisResponse(BaseResponse): class AmisResponse(BaseResponse):
def create_image(self): def create_image(self):
name = self.querystring.get('Name')[0] name = self.querystring.get('Name')[0]
if "Description" in self.querystring: description = self._get_param('Description', if_none='')
description = self.querystring.get('Description')[0] instance_id = self._get_param('InstanceId')
else:
description = ""
instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0]
if self.is_not_dryrun('CreateImage'): if self.is_not_dryrun('CreateImage'):
image = self.ec2_backend.create_image( image = self.ec2_backend.create_image(
instance_id, name, description) instance_id, name, description)
@ -21,12 +16,10 @@ class AmisResponse(BaseResponse):
return template.render(image=image) return template.render(image=image)
def copy_image(self): def copy_image(self):
source_image_id = self.querystring.get('SourceImageId')[0] source_image_id = self._get_param('SourceImageId')
source_region = self.querystring.get('SourceRegion')[0] source_region = self._get_param('SourceRegion')
name = self.querystring.get( name = self._get_param('Name')
'Name')[0] if self.querystring.get('Name') else None description = self._get_param('Description')
description = self.querystring.get(
'Description')[0] if self.querystring.get('Description') else None
if self.is_not_dryrun('CopyImage'): if self.is_not_dryrun('CopyImage'):
image = self.ec2_backend.copy_image( image = self.ec2_backend.copy_image(
source_image_id, source_region, name, description) source_image_id, source_region, name, description)
@ -34,33 +27,33 @@ class AmisResponse(BaseResponse):
return template.render(image=image) return template.render(image=image)
def deregister_image(self): def deregister_image(self):
ami_id = self.querystring.get('ImageId')[0] ami_id = self._get_param('ImageId')
if self.is_not_dryrun('DeregisterImage'): if self.is_not_dryrun('DeregisterImage'):
success = self.ec2_backend.deregister_image(ami_id) success = self.ec2_backend.deregister_image(ami_id)
template = self.response_template(DEREGISTER_IMAGE_RESPONSE) template = self.response_template(DEREGISTER_IMAGE_RESPONSE)
return template.render(success=str(success).lower()) return template.render(success=str(success).lower())
def describe_images(self): def describe_images(self):
ami_ids = image_ids_from_querystring(self.querystring) ami_ids = self._get_multi_param('ImageId')
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
exec_users = executable_users_from_querystring(self.querystring) exec_users = self._get_multi_param('ExecutableBy')
images = self.ec2_backend.describe_images( images = self.ec2_backend.describe_images(
ami_ids=ami_ids, filters=filters, exec_users=exec_users) ami_ids=ami_ids, filters=filters, exec_users=exec_users)
template = self.response_template(DESCRIBE_IMAGES_RESPONSE) template = self.response_template(DESCRIBE_IMAGES_RESPONSE)
return template.render(images=images) return template.render(images=images)
def describe_image_attribute(self): def describe_image_attribute(self):
ami_id = self.querystring.get('ImageId')[0] ami_id = self._get_param('ImageId')
groups = self.ec2_backend.get_launch_permission_groups(ami_id) groups = self.ec2_backend.get_launch_permission_groups(ami_id)
users = self.ec2_backend.get_launch_permission_users(ami_id) users = self.ec2_backend.get_launch_permission_users(ami_id)
template = self.response_template(DESCRIBE_IMAGE_ATTRIBUTES_RESPONSE) template = self.response_template(DESCRIBE_IMAGE_ATTRIBUTES_RESPONSE)
return template.render(ami_id=ami_id, groups=groups, users=users) return template.render(ami_id=ami_id, groups=groups, users=users)
def modify_image_attribute(self): def modify_image_attribute(self):
ami_id = self.querystring.get('ImageId')[0] ami_id = self._get_param('ImageId')
operation_type = self.querystring.get('OperationType')[0] operation_type = self._get_param('OperationType')
group = self.querystring.get('UserGroup.1', [None])[0] group = self._get_param('UserGroup.1')
user_ids = sequence_from_querystring('UserId', self.querystring) user_ids = self._get_multi_param('UserId')
if self.is_not_dryrun('ModifyImageAttribute'): if self.is_not_dryrun('ModifyImageAttribute'):
if (operation_type == 'add'): if (operation_type == 'add'):
self.ec2_backend.add_launch_permission( self.ec2_backend.add_launch_permission(
@ -115,7 +108,7 @@ DESCRIBE_IMAGES_RESPONSE = """<DescribeImagesResponse xmlns="http://ec2.amazonaw
{% endif %} {% endif %}
<description>{{ image.description }}</description> <description>{{ image.description }}</description>
<rootDeviceType>ebs</rootDeviceType> <rootDeviceType>ebs</rootDeviceType>
<rootDeviceName>/dev/sda</rootDeviceName> <rootDeviceName>/dev/sda1</rootDeviceName>
<blockDeviceMapping> <blockDeviceMapping>
<item> <item>
<deviceName>/dev/sda1</deviceName> <deviceName>/dev/sda1</deviceName>

View File

@ -7,16 +7,16 @@ class CustomerGateways(BaseResponse):
def create_customer_gateway(self): def create_customer_gateway(self):
# raise NotImplementedError('CustomerGateways(AmazonVPC).create_customer_gateway is not yet implemented') # raise NotImplementedError('CustomerGateways(AmazonVPC).create_customer_gateway is not yet implemented')
type = self.querystring.get('Type', None)[0] type = self._get_param('Type')
ip_address = self.querystring.get('IpAddress', None)[0] ip_address = self._get_param('IpAddress')
bgp_asn = self.querystring.get('BgpAsn', None)[0] bgp_asn = self._get_param('BgpAsn')
customer_gateway = self.ec2_backend.create_customer_gateway( customer_gateway = self.ec2_backend.create_customer_gateway(
type, ip_address=ip_address, bgp_asn=bgp_asn) type, ip_address=ip_address, bgp_asn=bgp_asn)
template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE) template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=customer_gateway) return template.render(customer_gateway=customer_gateway)
def delete_customer_gateway(self): def delete_customer_gateway(self):
customer_gateway_id = self.querystring.get('CustomerGatewayId')[0] customer_gateway_id = self._get_param('CustomerGatewayId')
delete_status = self.ec2_backend.delete_customer_gateway( delete_status = self.ec2_backend.delete_customer_gateway(
customer_gateway_id) customer_gateway_id)
template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE) template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE)

View File

@ -2,15 +2,14 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import ( from moto.ec2.utils import (
filters_from_querystring, filters_from_querystring,
sequence_from_querystring,
dhcp_configuration_from_querystring) dhcp_configuration_from_querystring)
class DHCPOptions(BaseResponse): class DHCPOptions(BaseResponse):
def associate_dhcp_options(self): def associate_dhcp_options(self):
dhcp_opt_id = self.querystring.get("DhcpOptionsId", [None])[0] dhcp_opt_id = self._get_param('DhcpOptionsId')
vpc_id = self.querystring.get("VpcId", [None])[0] vpc_id = self._get_param('VpcId')
dhcp_opt = self.ec2_backend.describe_dhcp_options([dhcp_opt_id])[0] dhcp_opt = self.ec2_backend.describe_dhcp_options([dhcp_opt_id])[0]
vpc = self.ec2_backend.get_vpc(vpc_id) vpc = self.ec2_backend.get_vpc(vpc_id)
@ -43,14 +42,13 @@ class DHCPOptions(BaseResponse):
return template.render(dhcp_options_set=dhcp_options_set) return template.render(dhcp_options_set=dhcp_options_set)
def delete_dhcp_options(self): def delete_dhcp_options(self):
dhcp_opt_id = self.querystring.get("DhcpOptionsId", [None])[0] dhcp_opt_id = self._get_param('DhcpOptionsId')
delete_status = self.ec2_backend.delete_dhcp_options_set(dhcp_opt_id) delete_status = self.ec2_backend.delete_dhcp_options_set(dhcp_opt_id)
template = self.response_template(DELETE_DHCP_OPTIONS_RESPONSE) template = self.response_template(DELETE_DHCP_OPTIONS_RESPONSE)
return template.render(delete_status=delete_status) return template.render(delete_status=delete_status)
def describe_dhcp_options(self): def describe_dhcp_options(self):
dhcp_opt_ids = sequence_from_querystring( dhcp_opt_ids = self._get_multi_param("DhcpOptionsId")
"DhcpOptionsId", self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
dhcp_opts = self.ec2_backend.get_all_dhcp_options( dhcp_opts = self.ec2_backend.get_all_dhcp_options(
dhcp_opt_ids, filters) dhcp_opt_ids, filters)

View File

@ -6,9 +6,9 @@ from moto.ec2.utils import filters_from_querystring
class ElasticBlockStore(BaseResponse): class ElasticBlockStore(BaseResponse):
def attach_volume(self): def attach_volume(self):
volume_id = self.querystring.get('VolumeId')[0] volume_id = self._get_param('VolumeId')
instance_id = self.querystring.get('InstanceId')[0] instance_id = self._get_param('InstanceId')
device_path = self.querystring.get('Device')[0] device_path = self._get_param('Device')
if self.is_not_dryrun('AttachVolume'): if self.is_not_dryrun('AttachVolume'):
attachment = self.ec2_backend.attach_volume( attachment = self.ec2_backend.attach_volume(
volume_id, instance_id, device_path) volume_id, instance_id, device_path)
@ -21,18 +21,18 @@ class ElasticBlockStore(BaseResponse):
'ElasticBlockStore.copy_snapshot is not yet implemented') 'ElasticBlockStore.copy_snapshot is not yet implemented')
def create_snapshot(self): def create_snapshot(self):
description = self.querystring.get('Description', [None])[0] volume_id = self._get_param('VolumeId')
volume_id = self.querystring.get('VolumeId')[0] description = self._get_param('Description')
if self.is_not_dryrun('CreateSnapshot'): if self.is_not_dryrun('CreateSnapshot'):
snapshot = self.ec2_backend.create_snapshot(volume_id, description) snapshot = self.ec2_backend.create_snapshot(volume_id, description)
template = self.response_template(CREATE_SNAPSHOT_RESPONSE) template = self.response_template(CREATE_SNAPSHOT_RESPONSE)
return template.render(snapshot=snapshot) return template.render(snapshot=snapshot)
def create_volume(self): def create_volume(self):
size = self.querystring.get('Size', [None])[0] size = self._get_param('Size')
zone = self.querystring.get('AvailabilityZone', [None])[0] zone = self._get_param('AvailabilityZone')
snapshot_id = self.querystring.get('SnapshotId', [None])[0] snapshot_id = self._get_param('SnapshotId')
encrypted = self.querystring.get('Encrypted', ['false'])[0] encrypted = self._get_param('Encrypted', if_none=False)
if self.is_not_dryrun('CreateVolume'): if self.is_not_dryrun('CreateVolume'):
volume = self.ec2_backend.create_volume( volume = self.ec2_backend.create_volume(
size, zone, snapshot_id, encrypted) size, zone, snapshot_id, encrypted)
@ -40,40 +40,28 @@ class ElasticBlockStore(BaseResponse):
return template.render(volume=volume) return template.render(volume=volume)
def delete_snapshot(self): def delete_snapshot(self):
snapshot_id = self.querystring.get('SnapshotId')[0] snapshot_id = self._get_param('SnapshotId')
if self.is_not_dryrun('DeleteSnapshot'): if self.is_not_dryrun('DeleteSnapshot'):
self.ec2_backend.delete_snapshot(snapshot_id) self.ec2_backend.delete_snapshot(snapshot_id)
return DELETE_SNAPSHOT_RESPONSE return DELETE_SNAPSHOT_RESPONSE
def delete_volume(self): def delete_volume(self):
volume_id = self.querystring.get('VolumeId')[0] volume_id = self._get_param('VolumeId')
if self.is_not_dryrun('DeleteVolume'): if self.is_not_dryrun('DeleteVolume'):
self.ec2_backend.delete_volume(volume_id) self.ec2_backend.delete_volume(volume_id)
return DELETE_VOLUME_RESPONSE return DELETE_VOLUME_RESPONSE
def describe_snapshots(self): def describe_snapshots(self):
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
# querystring for multiple snapshotids results in SnapshotId.1, snapshot_ids = self._get_multi_param('SnapshotId')
# SnapshotId.2 etc snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters)
snapshot_ids = ','.join(
[','.join(s[1]) for s in self.querystring.items() if 'SnapshotId' in s[0]])
snapshots = self.ec2_backend.describe_snapshots(filters=filters)
# Describe snapshots to handle filter on snapshot_ids
snapshots = [
s for s in snapshots if s.id in snapshot_ids] if snapshot_ids else snapshots
template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE) template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE)
return template.render(snapshots=snapshots) return template.render(snapshots=snapshots)
def describe_volumes(self): def describe_volumes(self):
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
# querystring for multiple volumeids results in VolumeId.1, VolumeId.2 volume_ids = self._get_multi_param('VolumeId')
# etc volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters)
volume_ids = ','.join(
[','.join(v[1]) for v in self.querystring.items() if 'VolumeId' in v[0]])
volumes = self.ec2_backend.describe_volumes(filters=filters)
# Describe volumes to handle filter on volume_ids
volumes = [
v for v in volumes if v.id in volume_ids] if volume_ids else volumes
template = self.response_template(DESCRIBE_VOLUMES_RESPONSE) template = self.response_template(DESCRIBE_VOLUMES_RESPONSE)
return template.render(volumes=volumes) return template.render(volumes=volumes)
@ -86,9 +74,9 @@ class ElasticBlockStore(BaseResponse):
'ElasticBlockStore.describe_volume_status is not yet implemented') 'ElasticBlockStore.describe_volume_status is not yet implemented')
def detach_volume(self): def detach_volume(self):
volume_id = self.querystring.get('VolumeId')[0] volume_id = self._get_param('VolumeId')
instance_id = self.querystring.get('InstanceId')[0] instance_id = self._get_param('InstanceId')
device_path = self.querystring.get('Device')[0] device_path = self._get_param('Device')
if self.is_not_dryrun('DetachVolume'): if self.is_not_dryrun('DetachVolume'):
attachment = self.ec2_backend.detach_volume( attachment = self.ec2_backend.detach_volume(
volume_id, instance_id, device_path) volume_id, instance_id, device_path)
@ -106,7 +94,7 @@ class ElasticBlockStore(BaseResponse):
'ElasticBlockStore.import_volume is not yet implemented') 'ElasticBlockStore.import_volume is not yet implemented')
def describe_snapshot_attribute(self): def describe_snapshot_attribute(self):
snapshot_id = self.querystring.get('SnapshotId')[0] snapshot_id = self._get_param('SnapshotId')
groups = self.ec2_backend.get_create_volume_permission_groups( groups = self.ec2_backend.get_create_volume_permission_groups(
snapshot_id) snapshot_id)
template = self.response_template( template = self.response_template(
@ -114,10 +102,10 @@ class ElasticBlockStore(BaseResponse):
return template.render(snapshot_id=snapshot_id, groups=groups) return template.render(snapshot_id=snapshot_id, groups=groups)
def modify_snapshot_attribute(self): def modify_snapshot_attribute(self):
snapshot_id = self.querystring.get('SnapshotId')[0] snapshot_id = self._get_param('SnapshotId')
operation_type = self.querystring.get('OperationType')[0] operation_type = self._get_param('OperationType')
group = self.querystring.get('UserGroup.1', [None])[0] group = self._get_param('UserGroup.1')
user_id = self.querystring.get('UserId.1', [None])[0] user_id = self._get_param('UserId.1')
if self.is_not_dryrun('ModifySnapshotAttribute'): if self.is_not_dryrun('ModifySnapshotAttribute'):
if (operation_type == 'add'): if (operation_type == 'add'):
self.ec2_backend.add_create_volume_permission( self.ec2_backend.add_create_volume_permission(

View File

@ -1,15 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import sequence_from_querystring from moto.ec2.utils import filters_from_querystring
class ElasticIPAddresses(BaseResponse): class ElasticIPAddresses(BaseResponse):
def allocate_address(self): def allocate_address(self):
if "Domain" in self.querystring: domain = self._get_param('Domain', if_none='standard')
domain = self.querystring.get('Domain')[0]
else:
domain = "standard"
if self.is_not_dryrun('AllocateAddress'): if self.is_not_dryrun('AllocateAddress'):
address = self.ec2_backend.allocate_address(domain) address = self.ec2_backend.allocate_address(domain)
template = self.response_template(ALLOCATE_ADDRESS_RESPONSE) template = self.response_template(ALLOCATE_ADDRESS_RESPONSE)
@ -20,26 +17,28 @@ class ElasticIPAddresses(BaseResponse):
if "InstanceId" in self.querystring: if "InstanceId" in self.querystring:
instance = self.ec2_backend.get_instance( instance = self.ec2_backend.get_instance(
self.querystring['InstanceId'][0]) self._get_param('InstanceId'))
elif "NetworkInterfaceId" in self.querystring: elif "NetworkInterfaceId" in self.querystring:
eni = self.ec2_backend.get_network_interface( eni = self.ec2_backend.get_network_interface(
self.querystring['NetworkInterfaceId'][0]) self._get_param('NetworkInterfaceId'))
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.") "MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.")
reassociate = False reassociate = False
if "AllowReassociation" in self.querystring: if "AllowReassociation" in self.querystring:
reassociate = self.querystring['AllowReassociation'][0] == "true" reassociate = self._get_param('AllowReassociation') == "true"
if self.is_not_dryrun('AssociateAddress'): if self.is_not_dryrun('AssociateAddress'):
if instance or eni: if instance or eni:
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
eip = self.ec2_backend.associate_address(instance=instance, eni=eni, address=self.querystring[ eip = self.ec2_backend.associate_address(
'PublicIp'][0], reassociate=reassociate) instance=instance, eni=eni,
address=self._get_param('PublicIp'), reassociate=reassociate)
elif "AllocationId" in self.querystring: elif "AllocationId" in self.querystring:
eip = self.ec2_backend.associate_address(instance=instance, eni=eni, allocation_id=self.querystring[ eip = self.ec2_backend.associate_address(
'AllocationId'][0], reassociate=reassociate) instance=instance, eni=eni,
allocation_id=self._get_param('AllocationId'), reassociate=reassociate)
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") "MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.")
@ -51,39 +50,22 @@ class ElasticIPAddresses(BaseResponse):
return template.render(address=eip) return template.render(address=eip)
def describe_addresses(self): def describe_addresses(self):
allocation_ids = self._get_multi_param('AllocationId')
public_ips = self._get_multi_param('PublicIp')
filters = filters_from_querystring(self.querystring)
addresses = self.ec2_backend.describe_addresses(
allocation_ids, public_ips, filters)
template = self.response_template(DESCRIBE_ADDRESS_RESPONSE) template = self.response_template(DESCRIBE_ADDRESS_RESPONSE)
if "Filter.1.Name" in self.querystring:
filter_by = sequence_from_querystring(
"Filter.1.Name", self.querystring)[0]
filter_value = sequence_from_querystring(
"Filter.1.Value", self.querystring)
if filter_by == 'instance-id':
addresses = filter(lambda x: x.instance.id == filter_value[
0], self.ec2_backend.describe_addresses())
else:
raise NotImplementedError(
"Filtering not supported in describe_address.")
elif "PublicIp.1" in self.querystring:
public_ips = sequence_from_querystring(
"PublicIp", self.querystring)
addresses = self.ec2_backend.address_by_ip(public_ips)
elif "AllocationId.1" in self.querystring:
allocation_ids = sequence_from_querystring(
"AllocationId", self.querystring)
addresses = self.ec2_backend.address_by_allocation(allocation_ids)
else:
addresses = self.ec2_backend.describe_addresses()
return template.render(addresses=addresses) return template.render(addresses=addresses)
def disassociate_address(self): def disassociate_address(self):
if self.is_not_dryrun('DisAssociateAddress'): if self.is_not_dryrun('DisAssociateAddress'):
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
self.ec2_backend.disassociate_address( self.ec2_backend.disassociate_address(
address=self.querystring['PublicIp'][0]) address=self._get_param('PublicIp'))
elif "AssociationId" in self.querystring: elif "AssociationId" in self.querystring:
self.ec2_backend.disassociate_address( self.ec2_backend.disassociate_address(
association_id=self.querystring['AssociationId'][0]) association_id=self._get_param('AssociationId'))
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.") "MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.")
@ -94,10 +76,10 @@ class ElasticIPAddresses(BaseResponse):
if self.is_not_dryrun('ReleaseAddress'): if self.is_not_dryrun('ReleaseAddress'):
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
self.ec2_backend.release_address( self.ec2_backend.release_address(
address=self.querystring['PublicIp'][0]) address=self._get_param('PublicIp'))
elif "AllocationId" in self.querystring: elif "AllocationId" in self.querystring:
self.ec2_backend.release_address( self.ec2_backend.release_address(
allocation_id=self.querystring['AllocationId'][0]) allocation_id=self._get_param('AllocationId'))
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") "MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.")

View File

@ -1,15 +1,14 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import sequence_from_querystring, filters_from_querystring from moto.ec2.utils import filters_from_querystring
class ElasticNetworkInterfaces(BaseResponse): class ElasticNetworkInterfaces(BaseResponse):
def create_network_interface(self): def create_network_interface(self):
subnet_id = self.querystring.get('SubnetId')[0] subnet_id = self._get_param('SubnetId')
private_ip_address = self.querystring.get( private_ip_address = self._get_param('PrivateIpAddress')
'PrivateIpAddress', [None])[0] groups = self._get_multi_param('SecurityGroupId')
groups = sequence_from_querystring('SecurityGroupId', self.querystring)
subnet = self.ec2_backend.get_subnet(subnet_id) subnet = self.ec2_backend.get_subnet(subnet_id)
if self.is_not_dryrun('CreateNetworkInterface'): if self.is_not_dryrun('CreateNetworkInterface'):
eni = self.ec2_backend.create_network_interface( eni = self.ec2_backend.create_network_interface(
@ -19,7 +18,7 @@ class ElasticNetworkInterfaces(BaseResponse):
return template.render(eni=eni) return template.render(eni=eni)
def delete_network_interface(self): def delete_network_interface(self):
eni_id = self.querystring.get('NetworkInterfaceId')[0] eni_id = self._get_param('NetworkInterfaceId')
if self.is_not_dryrun('DeleteNetworkInterface'): if self.is_not_dryrun('DeleteNetworkInterface'):
self.ec2_backend.delete_network_interface(eni_id) self.ec2_backend.delete_network_interface(eni_id)
template = self.response_template( template = self.response_template(
@ -31,17 +30,16 @@ class ElasticNetworkInterfaces(BaseResponse):
'ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented') 'ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented')
def describe_network_interfaces(self): def describe_network_interfaces(self):
eni_ids = sequence_from_querystring( eni_ids = self._get_multi_param('NetworkInterfaceId')
'NetworkInterfaceId', self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters) enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters)
template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE) template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE)
return template.render(enis=enis) return template.render(enis=enis)
def attach_network_interface(self): def attach_network_interface(self):
eni_id = self.querystring.get('NetworkInterfaceId')[0] eni_id = self._get_param('NetworkInterfaceId')
instance_id = self.querystring.get('InstanceId')[0] instance_id = self._get_param('InstanceId')
device_index = self.querystring.get('DeviceIndex')[0] device_index = self._get_param('DeviceIndex')
if self.is_not_dryrun('AttachNetworkInterface'): if self.is_not_dryrun('AttachNetworkInterface'):
attachment_id = self.ec2_backend.attach_network_interface( attachment_id = self.ec2_backend.attach_network_interface(
eni_id, instance_id, device_index) eni_id, instance_id, device_index)
@ -50,7 +48,7 @@ class ElasticNetworkInterfaces(BaseResponse):
return template.render(attachment_id=attachment_id) return template.render(attachment_id=attachment_id)
def detach_network_interface(self): def detach_network_interface(self):
attachment_id = self.querystring.get('AttachmentId')[0] attachment_id = self._get_param('AttachmentId')
if self.is_not_dryrun('DetachNetworkInterface'): if self.is_not_dryrun('DetachNetworkInterface'):
self.ec2_backend.detach_network_interface(attachment_id) self.ec2_backend.detach_network_interface(attachment_id)
template = self.response_template( template = self.response_template(
@ -59,8 +57,8 @@ class ElasticNetworkInterfaces(BaseResponse):
def modify_network_interface_attribute(self): def modify_network_interface_attribute(self):
# Currently supports modifying one and only one security group # Currently supports modifying one and only one security group
eni_id = self.querystring.get('NetworkInterfaceId')[0] eni_id = self._get_param('NetworkInterfaceId')
group_id = self.querystring.get('SecurityGroupId.1')[0] group_id = self._get_param('SecurityGroupId.1')
if self.is_not_dryrun('ModifyNetworkInterface'): if self.is_not_dryrun('ModifyNetworkInterface'):
self.ec2_backend.modify_network_interface_attribute( self.ec2_backend.modify_network_interface_attribute(
eni_id, group_id) eni_id, group_id)

View File

@ -1,13 +1,16 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import instance_ids_from_querystring
class General(BaseResponse): class General(BaseResponse):
def get_console_output(self): def get_console_output(self):
self.instance_ids = instance_ids_from_querystring(self.querystring) instance_id = self._get_param('InstanceId')
instance_id = self.instance_ids[0] if not instance_id:
# For compatibility with boto.
# See: https://github.com/spulec/moto/pull/1152#issuecomment-332487599
instance_id = self._get_multi_param('InstanceId')[0]
instance = self.ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
template = self.response_template(GET_CONSOLE_OUTPUT_RESULT) template = self.response_template(GET_CONSOLE_OUTPUT_RESULT)
return template.render(instance=instance) return template.render(instance=instance)

View File

@ -2,15 +2,15 @@ from __future__ import unicode_literals
from boto.ec2.instancetype import InstanceType from boto.ec2.instancetype import InstanceType
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import instance_ids_from_querystring, filters_from_querystring, \ from moto.ec2.utils import filters_from_querystring, \
dict_from_querystring, optional_from_querystring dict_from_querystring
class InstanceResponse(BaseResponse): class InstanceResponse(BaseResponse):
def describe_instances(self): def describe_instances(self):
filter_dict = filters_from_querystring(self.querystring) filter_dict = filters_from_querystring(self.querystring)
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = self._get_multi_param('InstanceId')
token = self._get_param("NextToken") token = self._get_param("NextToken")
if instance_ids: if instance_ids:
reservations = self.ec2_backend.get_reservations_by_instance_ids( reservations = self.ec2_backend.get_reservations_by_instance_ids(
@ -30,29 +30,28 @@ class InstanceResponse(BaseResponse):
if max_results and len(reservations) > (start + max_results): if max_results and len(reservations) > (start + max_results):
next_token = reservations_resp[-1].id next_token = reservations_resp[-1].id
template = self.response_template(EC2_DESCRIBE_INSTANCES) template = self.response_template(EC2_DESCRIBE_INSTANCES)
return template.render(reservations=reservations_resp, next_token=next_token) return template.render(reservations=reservations_resp, next_token=next_token).replace('True', 'true').replace('False', 'false')
def run_instances(self): def run_instances(self):
min_count = int(self.querystring.get('MinCount', ['1'])[0]) min_count = int(self._get_param('MinCount', if_none='1'))
image_id = self.querystring.get('ImageId')[0] image_id = self._get_param('ImageId')
user_data = self.querystring.get('UserData') user_data = self._get_param('UserData')
security_group_names = self._get_multi_param('SecurityGroup') security_group_names = self._get_multi_param('SecurityGroup')
security_group_ids = self._get_multi_param('SecurityGroupId') security_group_ids = self._get_multi_param('SecurityGroupId')
nics = dict_from_querystring("NetworkInterface", self.querystring) nics = dict_from_querystring("NetworkInterface", self.querystring)
instance_type = self.querystring.get("InstanceType", ["m1.small"])[0] instance_type = self._get_param('InstanceType', if_none='m1.small')
placement = self.querystring.get( placement = self._get_param('Placement.AvailabilityZone')
"Placement.AvailabilityZone", [None])[0] subnet_id = self._get_param('SubnetId')
subnet_id = self.querystring.get("SubnetId", [None])[0] private_ip = self._get_param('PrivateIpAddress')
private_ip = self.querystring.get("PrivateIpAddress", [None])[0] associate_public_ip = self._get_param('AssociatePublicIpAddress')
associate_public_ip = self.querystring.get( key_name = self._get_param('KeyName')
"AssociatePublicIpAddress", [None])[0]
key_name = self.querystring.get("KeyName", [None])[0]
tags = self._parse_tag_specification("TagSpecification") tags = self._parse_tag_specification("TagSpecification")
region_name = self.region
if self.is_not_dryrun('RunInstance'): if self.is_not_dryrun('RunInstance'):
new_reservation = self.ec2_backend.add_instances( new_reservation = self.ec2_backend.add_instances(
image_id, min_count, user_data, security_group_names, image_id, min_count, user_data, security_group_names,
instance_type=instance_type, placement=placement, subnet_id=subnet_id, instance_type=instance_type, placement=placement, region_name=region_name, subnet_id=subnet_id,
key_name=key_name, security_group_ids=security_group_ids, key_name=key_name, security_group_ids=security_group_ids,
nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip, nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip,
tags=tags) tags=tags)
@ -61,37 +60,36 @@ class InstanceResponse(BaseResponse):
return template.render(reservation=new_reservation) return template.render(reservation=new_reservation)
def terminate_instances(self): def terminate_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('TerminateInstance'): if self.is_not_dryrun('TerminateInstance'):
instances = self.ec2_backend.terminate_instances(instance_ids) instances = self.ec2_backend.terminate_instances(instance_ids)
template = self.response_template(EC2_TERMINATE_INSTANCES) template = self.response_template(EC2_TERMINATE_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def reboot_instances(self): def reboot_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('RebootInstance'): if self.is_not_dryrun('RebootInstance'):
instances = self.ec2_backend.reboot_instances(instance_ids) instances = self.ec2_backend.reboot_instances(instance_ids)
template = self.response_template(EC2_REBOOT_INSTANCES) template = self.response_template(EC2_REBOOT_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def stop_instances(self): def stop_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('StopInstance'): if self.is_not_dryrun('StopInstance'):
instances = self.ec2_backend.stop_instances(instance_ids) instances = self.ec2_backend.stop_instances(instance_ids)
template = self.response_template(EC2_STOP_INSTANCES) template = self.response_template(EC2_STOP_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def start_instances(self): def start_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('StartInstance'): if self.is_not_dryrun('StartInstance'):
instances = self.ec2_backend.start_instances(instance_ids) instances = self.ec2_backend.start_instances(instance_ids)
template = self.response_template(EC2_START_INSTANCES) template = self.response_template(EC2_START_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def describe_instance_status(self): def describe_instance_status(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = self._get_multi_param('InstanceId')
include_all_instances = optional_from_querystring('IncludeAllInstances', include_all_instances = self._get_param('IncludeAllInstances') == 'true'
self.querystring) == 'true'
if instance_ids: if instance_ids:
instances = self.ec2_backend.get_multi_instances_by_id( instances = self.ec2_backend.get_multi_instances_by_id(
@ -113,10 +111,9 @@ class InstanceResponse(BaseResponse):
def describe_instance_attribute(self): def describe_instance_attribute(self):
# TODO this and modify below should raise IncorrectInstanceState if # TODO this and modify below should raise IncorrectInstanceState if
# instance not in stopped state # instance not in stopped state
attribute = self.querystring.get("Attribute")[0] attribute = self._get_param('Attribute')
key = camelcase_to_underscores(attribute) key = camelcase_to_underscores(attribute)
instance_ids = instance_ids_from_querystring(self.querystring) instance_id = self._get_param('InstanceId')
instance_id = instance_ids[0]
instance, value = self.ec2_backend.describe_instance_attribute( instance, value = self.ec2_backend.describe_instance_attribute(
instance_id, key) instance_id, key)
@ -147,7 +144,12 @@ class InstanceResponse(BaseResponse):
""" """
Handles requests which are generated by code similar to: Handles requests which are generated by code similar to:
instance.modify_attribute('blockDeviceMapping', {'/dev/sda1': True}) instance.modify_attribute(
BlockDeviceMappings=[{
'DeviceName': '/dev/sda1',
'Ebs': {'DeleteOnTermination': True}
}]
)
The querystring contains information similar to: The querystring contains information similar to:
@ -170,8 +172,7 @@ class InstanceResponse(BaseResponse):
del_on_term_value = True if 'true' == del_on_term_value_str else False del_on_term_value = True if 'true' == del_on_term_value_str else False
device_name_value = self.querystring[mapping_device_name][0] device_name_value = self.querystring[mapping_device_name][0]
instance_ids = instance_ids_from_querystring(self.querystring) instance_id = self._get_param('InstanceId')
instance_id = instance_ids[0]
instance = self.ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
if self.is_not_dryrun('ModifyInstanceAttribute'): if self.is_not_dryrun('ModifyInstanceAttribute'):
@ -199,8 +200,7 @@ class InstanceResponse(BaseResponse):
value = self.querystring.get(attribute_key)[0] value = self.querystring.get(attribute_key)[0]
normalized_attribute = camelcase_to_underscores( normalized_attribute = camelcase_to_underscores(
attribute_key.split(".")[0]) attribute_key.split(".")[0])
instance_ids = instance_ids_from_querystring(self.querystring) instance_id = self._get_param('InstanceId')
instance_id = instance_ids[0]
self.ec2_backend.modify_instance_attribute( self.ec2_backend.modify_instance_attribute(
instance_id, normalized_attribute, value) instance_id, normalized_attribute, value)
return EC2_MODIFY_INSTANCE_ATTRIBUTE return EC2_MODIFY_INSTANCE_ATTRIBUTE
@ -211,8 +211,7 @@ class InstanceResponse(BaseResponse):
if 'GroupId.' in key: if 'GroupId.' in key:
new_security_grp_list.append(self.querystring.get(key)[0]) new_security_grp_list.append(self.querystring.get(key)[0])
instance_ids = instance_ids_from_querystring(self.querystring) instance_id = self._get_param('InstanceId')
instance_id = instance_ids[0]
if self.is_not_dryrun('ModifyInstanceSecurityGroups'): if self.is_not_dryrun('ModifyInstanceSecurityGroups'):
self.ec2_backend.modify_instance_security_groups( self.ec2_backend.modify_instance_security_groups(
instance_id, new_security_grp_list) instance_id, new_security_grp_list)
@ -254,17 +253,19 @@ EC2_RUN_INSTANCES = """<RunInstancesResponse xmlns="http://ec2.amazonaws.com/doc
<monitoring> <monitoring>
<state>enabled</state> <state>enabled</state>
</monitoring> </monitoring>
{% if instance.nics %} {% if instance.subnet_id %}
{% if instance.nics[0].subnet %}
<subnetId>{{ instance.nics[0].subnet.id }}</subnetId>
<vpcId>{{ instance.nics[0].subnet.vpc_id }}</vpcId>
{% endif %}
<privateIpAddress>{{ instance.private_ip }}</privateIpAddress>
{% if instance.public_ip %}
<ipAddress>{{ instance.public_ip }}</ipAddress>
{% endif %}
{% else %}
<subnetId>{{ instance.subnet_id }}</subnetId> <subnetId>{{ instance.subnet_id }}</subnetId>
{% elif instance.nics[0].subnet.id %}
<subnetId>{{ instance.nics[0].subnet.id }}</subnetId>
{% endif %}
{% if instance.vpc_id %}
<vpcId>{{ instance.vpc_id }}</vpcId>
{% elif instance.nics[0].subnet.vpc_id %}
<vpcId>{{ instance.nics[0].subnet.vpc_id }}</vpcId>
{% endif %}
<privateIpAddress>{{ instance.private_ip }}</privateIpAddress>
{% if instance.nics[0].public_ip %}
<ipAddress>{{ instance.nics[0].public_ip }}</ipAddress>
{% endif %} {% endif %}
<sourceDestCheck>{{ instance.source_dest_check }}</sourceDestCheck> <sourceDestCheck>{{ instance.source_dest_check }}</sourceDestCheck>
<groupSet> <groupSet>
@ -395,26 +396,30 @@ EC2_DESCRIBE_INSTANCES = """<DescribeInstancesResponse xmlns="http://ec2.amazona
<monitoring> <monitoring>
<state>disabled</state> <state>disabled</state>
</monitoring> </monitoring>
{% if instance.nics %} {% if instance.subnet_id %}
{% if instance.nics[0].subnet %} <subnetId>{{ instance.subnet_id }}</subnetId>
<subnetId>{{ instance.nics[0].subnet.id }}</subnetId> {% elif instance.nics[0].subnet.id %}
<vpcId>{{ instance.nics[0].subnet.vpc_id }}</vpcId> <subnetId>{{ instance.nics[0].subnet.id }}</subnetId>
{% endif %} {% endif %}
<privateIpAddress>{{ instance.private_ip }}</privateIpAddress> {% if instance.vpc_id %}
{% if instance.nics[0].public_ip %} <vpcId>{{ instance.vpc_id }}</vpcId>
<ipAddress>{{ instance.nics[0].public_ip }}</ipAddress> {% elif instance.nics[0].subnet.vpc_id %}
{% endif %} <vpcId>{{ instance.nics[0].subnet.vpc_id }}</vpcId>
{% endif %}
<privateIpAddress>{{ instance.private_ip }}</privateIpAddress>
{% if instance.nics[0].public_ip %}
<ipAddress>{{ instance.nics[0].public_ip }}</ipAddress>
{% endif %} {% endif %}
<sourceDestCheck>{{ instance.source_dest_check }}</sourceDestCheck> <sourceDestCheck>{{ instance.source_dest_check }}</sourceDestCheck>
<groupSet> <groupSet>
{% for group in instance.dynamic_group_list %} {% for group in instance.dynamic_group_list %}
<item> <item>
{% if group.id %} {% if group.id %}
<groupId>{{ group.id }}</groupId> <groupId>{{ group.id }}</groupId>
<groupName>{{ group.name }}</groupName> <groupName>{{ group.name }}</groupName>
{% else %} {% else %}
<groupId>{{ group }}</groupId> <groupId>{{ group }}</groupId>
{% endif %} {% endif %}
</item> </item>
{% endfor %} {% endfor %}
</groupSet> </groupSet>

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import ( from moto.ec2.utils import (
sequence_from_querystring,
filters_from_querystring, filters_from_querystring,
) )
@ -9,8 +8,8 @@ from moto.ec2.utils import (
class InternetGateways(BaseResponse): class InternetGateways(BaseResponse):
def attach_internet_gateway(self): def attach_internet_gateway(self):
igw_id = self.querystring.get("InternetGatewayId", [None])[0] igw_id = self._get_param('InternetGatewayId')
vpc_id = self.querystring.get("VpcId", [None])[0] vpc_id = self._get_param('VpcId')
if self.is_not_dryrun('AttachInternetGateway'): if self.is_not_dryrun('AttachInternetGateway'):
self.ec2_backend.attach_internet_gateway(igw_id, vpc_id) self.ec2_backend.attach_internet_gateway(igw_id, vpc_id)
template = self.response_template(ATTACH_INTERNET_GATEWAY_RESPONSE) template = self.response_template(ATTACH_INTERNET_GATEWAY_RESPONSE)
@ -23,7 +22,7 @@ class InternetGateways(BaseResponse):
return template.render(internet_gateway=igw) return template.render(internet_gateway=igw)
def delete_internet_gateway(self): def delete_internet_gateway(self):
igw_id = self.querystring.get("InternetGatewayId", [None])[0] igw_id = self._get_param('InternetGatewayId')
if self.is_not_dryrun('DeleteInternetGateway'): if self.is_not_dryrun('DeleteInternetGateway'):
self.ec2_backend.delete_internet_gateway(igw_id) self.ec2_backend.delete_internet_gateway(igw_id)
template = self.response_template(DELETE_INTERNET_GATEWAY_RESPONSE) template = self.response_template(DELETE_INTERNET_GATEWAY_RESPONSE)
@ -32,8 +31,7 @@ class InternetGateways(BaseResponse):
def describe_internet_gateways(self): def describe_internet_gateways(self):
filter_dict = filters_from_querystring(self.querystring) filter_dict = filters_from_querystring(self.querystring)
if "InternetGatewayId.1" in self.querystring: if "InternetGatewayId.1" in self.querystring:
igw_ids = sequence_from_querystring( igw_ids = self._get_multi_param("InternetGatewayId")
"InternetGatewayId", self.querystring)
igws = self.ec2_backend.describe_internet_gateways( igws = self.ec2_backend.describe_internet_gateways(
igw_ids, filters=filter_dict) igw_ids, filters=filter_dict)
else: else:
@ -46,8 +44,8 @@ class InternetGateways(BaseResponse):
def detach_internet_gateway(self): def detach_internet_gateway(self):
# TODO validate no instances with EIPs in VPC before detaching # TODO validate no instances with EIPs in VPC before detaching
# raise else DependencyViolationError() # raise else DependencyViolationError()
igw_id = self.querystring.get("InternetGatewayId", [None])[0] igw_id = self._get_param('InternetGatewayId')
vpc_id = self.querystring.get("VpcId", [None])[0] vpc_id = self._get_param('VpcId')
if self.is_not_dryrun('DetachInternetGateway'): if self.is_not_dryrun('DetachInternetGateway'):
self.ec2_backend.detach_internet_gateway(igw_id, vpc_id) self.ec2_backend.detach_internet_gateway(igw_id, vpc_id)
template = self.response_template(DETACH_INTERNET_GATEWAY_RESPONSE) template = self.response_template(DETACH_INTERNET_GATEWAY_RESPONSE)

View File

@ -1,43 +1,39 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import six import six
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import keypair_names_from_querystring, filters_from_querystring from moto.ec2.utils import filters_from_querystring
class KeyPairs(BaseResponse): class KeyPairs(BaseResponse):
def create_key_pair(self): def create_key_pair(self):
name = self.querystring.get('KeyName')[0] name = self._get_param('KeyName')
if self.is_not_dryrun('CreateKeyPair'): if self.is_not_dryrun('CreateKeyPair'):
keypair = self.ec2_backend.create_key_pair(name) keypair = self.ec2_backend.create_key_pair(name)
template = self.response_template(CREATE_KEY_PAIR_RESPONSE) template = self.response_template(CREATE_KEY_PAIR_RESPONSE)
return template.render(**keypair) return template.render(keypair=keypair)
def delete_key_pair(self): def delete_key_pair(self):
name = self.querystring.get('KeyName')[0] name = self._get_param('KeyName')
if self.is_not_dryrun('DeleteKeyPair'): if self.is_not_dryrun('DeleteKeyPair'):
success = six.text_type( success = six.text_type(
self.ec2_backend.delete_key_pair(name)).lower() self.ec2_backend.delete_key_pair(name)).lower()
return self.response_template(DELETE_KEY_PAIR_RESPONSE).render(success=success) return self.response_template(DELETE_KEY_PAIR_RESPONSE).render(success=success)
def describe_key_pairs(self): def describe_key_pairs(self):
names = keypair_names_from_querystring(self.querystring) names = self._get_multi_param('KeyName')
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
if len(filters) > 0: keypairs = self.ec2_backend.describe_key_pairs(names, filters)
raise NotImplementedError(
'Using filters in KeyPairs.describe_key_pairs is not yet implemented')
keypairs = self.ec2_backend.describe_key_pairs(names)
template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE) template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE)
return template.render(keypairs=keypairs) return template.render(keypairs=keypairs)
def import_key_pair(self): def import_key_pair(self):
name = self.querystring.get('KeyName')[0] name = self._get_param('KeyName')
material = self.querystring.get('PublicKeyMaterial')[0] material = self._get_param('PublicKeyMaterial')
if self.is_not_dryrun('ImportKeyPair'): if self.is_not_dryrun('ImportKeyPair'):
keypair = self.ec2_backend.import_key_pair(name, material) keypair = self.ec2_backend.import_key_pair(name, material)
template = self.response_template(IMPORT_KEYPAIR_RESPONSE) template = self.response_template(IMPORT_KEYPAIR_RESPONSE)
return template.render(**keypair) return template.render(keypair=keypair)
DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@ -54,12 +50,9 @@ DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.ama
CREATE_KEY_PAIR_RESPONSE = """<CreateKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> CREATE_KEY_PAIR_RESPONSE = """<CreateKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<keyName>{{ name }}</keyName> <keyName>{{ keypair.name }}</keyName>
<keyFingerprint> <keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
{{ fingerprint }} <keyMaterial>{{ keypair.material }}</keyMaterial>
</keyFingerprint>
<keyMaterial>{{ material }}
</keyMaterial>
</CreateKeyPairResponse>""" </CreateKeyPairResponse>"""
@ -71,6 +64,6 @@ DELETE_KEY_PAIR_RESPONSE = """<DeleteKeyPairResponse xmlns="http://ec2.amazonaws
IMPORT_KEYPAIR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?> IMPORT_KEYPAIR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ImportKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <ImportKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>471f9fdd-8fe2-4a84-86b0-bd3d3e350979</requestId> <requestId>471f9fdd-8fe2-4a84-86b0-bd3d3e350979</requestId>
<keyName>{{ name }}</keyName> <keyName>{{ keypair.name }}</keyName>
<keyFingerprint>{{ fingerprint }}</keyFingerprint> <keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
</ImportKeyPairResponse>""" </ImportKeyPairResponse>"""

View File

@ -1,28 +1,27 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring, \ from moto.ec2.utils import filters_from_querystring
network_acl_ids_from_querystring
class NetworkACLs(BaseResponse): class NetworkACLs(BaseResponse):
def create_network_acl(self): def create_network_acl(self):
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self._get_param('VpcId')
network_acl = self.ec2_backend.create_network_acl(vpc_id) network_acl = self.ec2_backend.create_network_acl(vpc_id)
template = self.response_template(CREATE_NETWORK_ACL_RESPONSE) template = self.response_template(CREATE_NETWORK_ACL_RESPONSE)
return template.render(network_acl=network_acl) return template.render(network_acl=network_acl)
def create_network_acl_entry(self): def create_network_acl_entry(self):
network_acl_id = self.querystring.get('NetworkAclId')[0] network_acl_id = self._get_param('NetworkAclId')
rule_number = self.querystring.get('RuleNumber')[0] rule_number = self._get_param('RuleNumber')
protocol = self.querystring.get('Protocol')[0] protocol = self._get_param('Protocol')
rule_action = self.querystring.get('RuleAction')[0] rule_action = self._get_param('RuleAction')
egress = self.querystring.get('Egress')[0] egress = self._get_param('Egress')
cidr_block = self.querystring.get('CidrBlock')[0] cidr_block = self._get_param('CidrBlock')
icmp_code = self.querystring.get('Icmp.Code', [None])[0] icmp_code = self._get_param('Icmp.Code')
icmp_type = self.querystring.get('Icmp.Type', [None])[0] icmp_type = self._get_param('Icmp.Type')
port_range_from = self.querystring.get('PortRange.From')[0] port_range_from = self._get_param('PortRange.From')
port_range_to = self.querystring.get('PortRange.To')[0] port_range_to = self._get_param('PortRange.To')
network_acl_entry = self.ec2_backend.create_network_acl_entry( network_acl_entry = self.ec2_backend.create_network_acl_entry(
network_acl_id, rule_number, protocol, rule_action, network_acl_id, rule_number, protocol, rule_action,
@ -33,30 +32,30 @@ class NetworkACLs(BaseResponse):
return template.render(network_acl_entry=network_acl_entry) return template.render(network_acl_entry=network_acl_entry)
def delete_network_acl(self): def delete_network_acl(self):
network_acl_id = self.querystring.get('NetworkAclId')[0] network_acl_id = self._get_param('NetworkAclId')
self.ec2_backend.delete_network_acl(network_acl_id) self.ec2_backend.delete_network_acl(network_acl_id)
template = self.response_template(DELETE_NETWORK_ACL_ASSOCIATION) template = self.response_template(DELETE_NETWORK_ACL_ASSOCIATION)
return template.render() return template.render()
def delete_network_acl_entry(self): def delete_network_acl_entry(self):
network_acl_id = self.querystring.get('NetworkAclId')[0] network_acl_id = self._get_param('NetworkAclId')
rule_number = self.querystring.get('RuleNumber')[0] rule_number = self._get_param('RuleNumber')
egress = self.querystring.get('Egress')[0] egress = self._get_param('Egress')
self.ec2_backend.delete_network_acl_entry(network_acl_id, rule_number, egress) self.ec2_backend.delete_network_acl_entry(network_acl_id, rule_number, egress)
template = self.response_template(DELETE_NETWORK_ACL_ENTRY_RESPONSE) template = self.response_template(DELETE_NETWORK_ACL_ENTRY_RESPONSE)
return template.render() return template.render()
def replace_network_acl_entry(self): def replace_network_acl_entry(self):
network_acl_id = self.querystring.get('NetworkAclId')[0] network_acl_id = self._get_param('NetworkAclId')
rule_number = self.querystring.get('RuleNumber')[0] rule_number = self._get_param('RuleNumber')
protocol = self.querystring.get('Protocol')[0] protocol = self._get_param('Protocol')
rule_action = self.querystring.get('RuleAction')[0] rule_action = self._get_param('RuleAction')
egress = self.querystring.get('Egress')[0] egress = self._get_param('Egress')
cidr_block = self.querystring.get('CidrBlock')[0] cidr_block = self._get_param('CidrBlock')
icmp_code = self.querystring.get('Icmp.Code', [None])[0] icmp_code = self._get_param('Icmp.Code')
icmp_type = self.querystring.get('Icmp.Type', [None])[0] icmp_type = self._get_param('Icmp.Type')
port_range_from = self.querystring.get('PortRange.From')[0] port_range_from = self._get_param('PortRange.From')
port_range_to = self.querystring.get('PortRange.To')[0] port_range_to = self._get_param('PortRange.To')
self.ec2_backend.replace_network_acl_entry( self.ec2_backend.replace_network_acl_entry(
network_acl_id, rule_number, protocol, rule_action, network_acl_id, rule_number, protocol, rule_action,
@ -67,7 +66,7 @@ class NetworkACLs(BaseResponse):
return template.render() return template.render()
def describe_network_acls(self): def describe_network_acls(self):
network_acl_ids = network_acl_ids_from_querystring(self.querystring) network_acl_ids = self._get_multi_param('NetworkAclId')
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
network_acls = self.ec2_backend.get_all_network_acls( network_acls = self.ec2_backend.get_all_network_acls(
network_acl_ids, filters) network_acl_ids, filters)
@ -75,8 +74,8 @@ class NetworkACLs(BaseResponse):
return template.render(network_acls=network_acls) return template.render(network_acls=network_acls)
def replace_network_acl_association(self): def replace_network_acl_association(self):
association_id = self.querystring.get('AssociationId')[0] association_id = self._get_param('AssociationId')
network_acl_id = self.querystring.get('NetworkAclId')[0] network_acl_id = self._get_param('NetworkAclId')
association = self.ec2_backend.replace_network_acl_association( association = self.ec2_backend.replace_network_acl_association(
association_id, association_id,

View File

@ -1,29 +1,25 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import route_table_ids_from_querystring, filters_from_querystring, optional_from_querystring from moto.ec2.utils import filters_from_querystring
class RouteTables(BaseResponse): class RouteTables(BaseResponse):
def associate_route_table(self): def associate_route_table(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self._get_param('RouteTableId')
subnet_id = self.querystring.get('SubnetId')[0] subnet_id = self._get_param('SubnetId')
association_id = self.ec2_backend.associate_route_table( association_id = self.ec2_backend.associate_route_table(
route_table_id, subnet_id) route_table_id, subnet_id)
template = self.response_template(ASSOCIATE_ROUTE_TABLE_RESPONSE) template = self.response_template(ASSOCIATE_ROUTE_TABLE_RESPONSE)
return template.render(association_id=association_id) return template.render(association_id=association_id)
def create_route(self): def create_route(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self._get_param('RouteTableId')
destination_cidr_block = self.querystring.get( destination_cidr_block = self._get_param('DestinationCidrBlock')
'DestinationCidrBlock')[0] gateway_id = self._get_param('GatewayId')
instance_id = self._get_param('InstanceId')
gateway_id = optional_from_querystring('GatewayId', self.querystring) interface_id = self._get_param('NetworkInterfaceId')
instance_id = optional_from_querystring('InstanceId', self.querystring) pcx_id = self._get_param('VpcPeeringConnectionId')
interface_id = optional_from_querystring(
'NetworkInterfaceId', self.querystring)
pcx_id = optional_from_querystring(
'VpcPeeringConnectionId', self.querystring)
self.ec2_backend.create_route(route_table_id, destination_cidr_block, self.ec2_backend.create_route(route_table_id, destination_cidr_block,
gateway_id=gateway_id, gateway_id=gateway_id,
@ -35,27 +31,26 @@ class RouteTables(BaseResponse):
return template.render() return template.render()
def create_route_table(self): def create_route_table(self):
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self._get_param('VpcId')
route_table = self.ec2_backend.create_route_table(vpc_id) route_table = self.ec2_backend.create_route_table(vpc_id)
template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE) template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE)
return template.render(route_table=route_table) return template.render(route_table=route_table)
def delete_route(self): def delete_route(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self._get_param('RouteTableId')
destination_cidr_block = self.querystring.get( destination_cidr_block = self._get_param('DestinationCidrBlock')
'DestinationCidrBlock')[0]
self.ec2_backend.delete_route(route_table_id, destination_cidr_block) self.ec2_backend.delete_route(route_table_id, destination_cidr_block)
template = self.response_template(DELETE_ROUTE_RESPONSE) template = self.response_template(DELETE_ROUTE_RESPONSE)
return template.render() return template.render()
def delete_route_table(self): def delete_route_table(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self._get_param('RouteTableId')
self.ec2_backend.delete_route_table(route_table_id) self.ec2_backend.delete_route_table(route_table_id)
template = self.response_template(DELETE_ROUTE_TABLE_RESPONSE) template = self.response_template(DELETE_ROUTE_TABLE_RESPONSE)
return template.render() return template.render()
def describe_route_tables(self): def describe_route_tables(self):
route_table_ids = route_table_ids_from_querystring(self.querystring) route_table_ids = self._get_multi_param('RouteTableId')
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
route_tables = self.ec2_backend.get_all_route_tables( route_tables = self.ec2_backend.get_all_route_tables(
route_table_ids, filters) route_table_ids, filters)
@ -63,22 +58,18 @@ class RouteTables(BaseResponse):
return template.render(route_tables=route_tables) return template.render(route_tables=route_tables)
def disassociate_route_table(self): def disassociate_route_table(self):
association_id = self.querystring.get('AssociationId')[0] association_id = self._get_param('AssociationId')
self.ec2_backend.disassociate_route_table(association_id) self.ec2_backend.disassociate_route_table(association_id)
template = self.response_template(DISASSOCIATE_ROUTE_TABLE_RESPONSE) template = self.response_template(DISASSOCIATE_ROUTE_TABLE_RESPONSE)
return template.render() return template.render()
def replace_route(self): def replace_route(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self._get_param('RouteTableId')
destination_cidr_block = self.querystring.get( destination_cidr_block = self._get_param('DestinationCidrBlock')
'DestinationCidrBlock')[0] gateway_id = self._get_param('GatewayId')
instance_id = self._get_param('InstanceId')
gateway_id = optional_from_querystring('GatewayId', self.querystring) interface_id = self._get_param('NetworkInterfaceId')
instance_id = optional_from_querystring('InstanceId', self.querystring) pcx_id = self._get_param('VpcPeeringConnectionId')
interface_id = optional_from_querystring(
'NetworkInterfaceId', self.querystring)
pcx_id = optional_from_querystring(
'VpcPeeringConnectionId', self.querystring)
self.ec2_backend.replace_route(route_table_id, destination_cidr_block, self.ec2_backend.replace_route(route_table_id, destination_cidr_block,
gateway_id=gateway_id, gateway_id=gateway_id,
@ -90,8 +81,8 @@ class RouteTables(BaseResponse):
return template.render() return template.render()
def replace_route_table_association(self): def replace_route_table_association(self):
route_table_id = self.querystring.get('RouteTableId')[0] route_table_id = self._get_param('RouteTableId')
association_id = self.querystring.get('AssociationId')[0] association_id = self._get_param('AssociationId')
new_association_id = self.ec2_backend.replace_route_table_association( new_association_id = self.ec2_backend.replace_route_table_association(
association_id, route_table_id) association_id, route_table_id)
template = self.response_template( template = self.response_template(

View File

@ -11,69 +11,66 @@ def try_parse_int(value, default=None):
return default return default
def process_rules_from_querystring(querystring):
try:
group_name_or_id = querystring.get('GroupName')[0]
except:
group_name_or_id = querystring.get('GroupId')[0]
querytree = {}
for key, value in querystring.items():
key_splitted = key.split('.')
key_splitted = [try_parse_int(e, e) for e in key_splitted]
d = querytree
for subkey in key_splitted[:-1]:
if subkey not in d:
d[subkey] = {}
d = d[subkey]
d[key_splitted[-1]] = value
ip_permissions = querytree.get('IpPermissions') or {}
for ip_permission_idx in sorted(ip_permissions.keys()):
ip_permission = ip_permissions[ip_permission_idx]
ip_protocol = ip_permission.get('IpProtocol', [None])[0]
from_port = ip_permission.get('FromPort', [None])[0]
to_port = ip_permission.get('ToPort', [None])[0]
ip_ranges = []
ip_ranges_tree = ip_permission.get('IpRanges') or {}
for ip_range_idx in sorted(ip_ranges_tree.keys()):
ip_ranges.append(ip_ranges_tree[ip_range_idx]['CidrIp'][0])
source_groups = []
source_group_ids = []
groups_tree = ip_permission.get('Groups') or {}
for group_idx in sorted(groups_tree.keys()):
group_dict = groups_tree[group_idx]
if 'GroupId' in group_dict:
source_group_ids.append(group_dict['GroupId'][0])
elif 'GroupName' in group_dict:
source_groups.append(group_dict['GroupName'][0])
yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges,
source_groups, source_group_ids)
class SecurityGroups(BaseResponse): class SecurityGroups(BaseResponse):
def _process_rules_from_querystring(self):
group_name_or_id = (self._get_param('GroupName') or
self._get_param('GroupId'))
querytree = {}
for key, value in self.querystring.items():
key_splitted = key.split('.')
key_splitted = [try_parse_int(e, e) for e in key_splitted]
d = querytree
for subkey in key_splitted[:-1]:
if subkey not in d:
d[subkey] = {}
d = d[subkey]
d[key_splitted[-1]] = value
ip_permissions = querytree.get('IpPermissions') or {}
for ip_permission_idx in sorted(ip_permissions.keys()):
ip_permission = ip_permissions[ip_permission_idx]
ip_protocol = ip_permission.get('IpProtocol', [None])[0]
from_port = ip_permission.get('FromPort', [None])[0]
to_port = ip_permission.get('ToPort', [None])[0]
ip_ranges = []
ip_ranges_tree = ip_permission.get('IpRanges') or {}
for ip_range_idx in sorted(ip_ranges_tree.keys()):
ip_ranges.append(ip_ranges_tree[ip_range_idx]['CidrIp'][0])
source_groups = []
source_group_ids = []
groups_tree = ip_permission.get('Groups') or {}
for group_idx in sorted(groups_tree.keys()):
group_dict = groups_tree[group_idx]
if 'GroupId' in group_dict:
source_group_ids.append(group_dict['GroupId'][0])
elif 'GroupName' in group_dict:
source_groups.append(group_dict['GroupName'][0])
yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges,
source_groups, source_group_ids)
def authorize_security_group_egress(self): def authorize_security_group_egress(self):
if self.is_not_dryrun('GrantSecurityGroupEgress'): if self.is_not_dryrun('GrantSecurityGroupEgress'):
for args in process_rules_from_querystring(self.querystring): for args in self._process_rules_from_querystring():
self.ec2_backend.authorize_security_group_egress(*args) self.ec2_backend.authorize_security_group_egress(*args)
return AUTHORIZE_SECURITY_GROUP_EGRESS_RESPONSE return AUTHORIZE_SECURITY_GROUP_EGRESS_RESPONSE
def authorize_security_group_ingress(self): def authorize_security_group_ingress(self):
if self.is_not_dryrun('GrantSecurityGroupIngress'): if self.is_not_dryrun('GrantSecurityGroupIngress'):
for args in process_rules_from_querystring(self.querystring): for args in self._process_rules_from_querystring():
self.ec2_backend.authorize_security_group_ingress(*args) self.ec2_backend.authorize_security_group_ingress(*args)
return AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE return AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE
def create_security_group(self): def create_security_group(self):
name = self.querystring.get('GroupName')[0] name = self._get_param('GroupName')
description = self.querystring.get('GroupDescription', [None])[0] description = self._get_param('GroupDescription')
vpc_id = self.querystring.get("VpcId", [None])[0] vpc_id = self._get_param('VpcId')
if self.is_not_dryrun('CreateSecurityGroup'): if self.is_not_dryrun('CreateSecurityGroup'):
group = self.ec2_backend.create_security_group( group = self.ec2_backend.create_security_group(
@ -86,14 +83,14 @@ class SecurityGroups(BaseResponse):
# See # See
# http://docs.aws.amazon.com/AWSEC2/latest/APIReference/ApiReference-query-DeleteSecurityGroup.html # http://docs.aws.amazon.com/AWSEC2/latest/APIReference/ApiReference-query-DeleteSecurityGroup.html
name = self.querystring.get('GroupName') name = self._get_param('GroupName')
sg_id = self.querystring.get('GroupId') sg_id = self._get_param('GroupId')
if self.is_not_dryrun('DeleteSecurityGroup'): if self.is_not_dryrun('DeleteSecurityGroup'):
if name: if name:
self.ec2_backend.delete_security_group(name[0]) self.ec2_backend.delete_security_group(name)
elif sg_id: elif sg_id:
self.ec2_backend.delete_security_group(group_id=sg_id[0]) self.ec2_backend.delete_security_group(group_id=sg_id)
return DELETE_GROUP_RESPONSE return DELETE_GROUP_RESPONSE
@ -113,7 +110,7 @@ class SecurityGroups(BaseResponse):
def revoke_security_group_egress(self): def revoke_security_group_egress(self):
if self.is_not_dryrun('RevokeSecurityGroupEgress'): if self.is_not_dryrun('RevokeSecurityGroupEgress'):
for args in process_rules_from_querystring(self.querystring): for args in self._process_rules_from_querystring():
success = self.ec2_backend.revoke_security_group_egress(*args) success = self.ec2_backend.revoke_security_group_egress(*args)
if not success: if not success:
return "Could not find a matching egress rule", dict(status=404) return "Could not find a matching egress rule", dict(status=404)
@ -121,7 +118,7 @@ class SecurityGroups(BaseResponse):
def revoke_security_group_ingress(self): def revoke_security_group_ingress(self):
if self.is_not_dryrun('RevokeSecurityGroupIngress'): if self.is_not_dryrun('RevokeSecurityGroupIngress'):
for args in process_rules_from_querystring(self.querystring): for args in self._process_rules_from_querystring():
self.ec2_backend.revoke_security_group_ingress(*args) self.ec2_backend.revoke_security_group_ingress(*args)
return REVOKE_SECURITY_GROUP_INGRESS_REPONSE return REVOKE_SECURITY_GROUP_INGRESS_REPONSE

View File

@ -29,6 +29,15 @@ class SpotFleets(BaseResponse):
template = self.response_template(DESCRIBE_SPOT_FLEET_TEMPLATE) template = self.response_template(DESCRIBE_SPOT_FLEET_TEMPLATE)
return template.render(requests=requests) return template.render(requests=requests)
def modify_spot_fleet_request(self):
spot_fleet_request_id = self._get_param("SpotFleetRequestId")
target_capacity = self._get_int_param("TargetCapacity")
terminate_instances = self._get_param("ExcessCapacityTerminationPolicy", if_none="Default")
successful = self.ec2_backend.modify_spot_fleet_request(
spot_fleet_request_id, target_capacity, terminate_instances)
template = self.response_template(MODIFY_SPOT_FLEET_REQUEST_TEMPLATE)
return template.render(successful=successful)
def request_spot_fleet(self): def request_spot_fleet(self):
spot_config = self._get_dict_param("SpotFleetRequestConfig.") spot_config = self._get_dict_param("SpotFleetRequestConfig.")
spot_price = spot_config['spot_price'] spot_price = spot_config['spot_price']
@ -56,6 +65,11 @@ REQUEST_SPOT_FLEET_TEMPLATE = """<RequestSpotFleetResponse xmlns="http://ec2.ama
<spotFleetRequestId>{{ request.id }}</spotFleetRequestId> <spotFleetRequestId>{{ request.id }}</spotFleetRequestId>
</RequestSpotFleetResponse>""" </RequestSpotFleetResponse>"""
MODIFY_SPOT_FLEET_REQUEST_TEMPLATE = """<ModifySpotFleetRequestResponse xmlns="http://ec2.amazonaws.com/doc/2016-09-15/">
<requestId>21681fea-9987-aef3-2121-example</requestId>
<return>{{ 'true' if successful else 'false' }}</return>
</ModifySpotFleetRequestResponse>"""
DESCRIBE_SPOT_FLEET_TEMPLATE = """<DescribeSpotFleetRequestsResponse xmlns="http://ec2.amazonaws.com/doc/2016-09-15/"> DESCRIBE_SPOT_FLEET_TEMPLATE = """<DescribeSpotFleetRequestsResponse xmlns="http://ec2.amazonaws.com/doc/2016-09-15/">
<requestId>4d68a6cc-8f2e-4be1-b425-example</requestId> <requestId>4d68a6cc-8f2e-4be1-b425-example</requestId>
<spotFleetRequestConfigSet> <spotFleetRequestConfigSet>

View File

@ -7,14 +7,11 @@ from moto.ec2.utils import filters_from_querystring
class Subnets(BaseResponse): class Subnets(BaseResponse):
def create_subnet(self): def create_subnet(self):
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self._get_param('VpcId')
cidr_block = self.querystring.get('CidrBlock')[0] cidr_block = self._get_param('CidrBlock')
if 'AvailabilityZone' in self.querystring: availability_zone = self._get_param(
availability_zone = self.querystring['AvailabilityZone'][0] 'AvailabilityZone', if_none=random.choice(
else: self.ec2_backend.describe_availability_zones()).name)
zone = random.choice(
self.ec2_backend.describe_availability_zones())
availability_zone = zone.name
subnet = self.ec2_backend.create_subnet( subnet = self.ec2_backend.create_subnet(
vpc_id, vpc_id,
cidr_block, cidr_block,
@ -24,30 +21,21 @@ class Subnets(BaseResponse):
return template.render(subnet=subnet) return template.render(subnet=subnet)
def delete_subnet(self): def delete_subnet(self):
subnet_id = self.querystring.get('SubnetId')[0] subnet_id = self._get_param('SubnetId')
subnet = self.ec2_backend.delete_subnet(subnet_id) subnet = self.ec2_backend.delete_subnet(subnet_id)
template = self.response_template(DELETE_SUBNET_RESPONSE) template = self.response_template(DELETE_SUBNET_RESPONSE)
return template.render(subnet=subnet) return template.render(subnet=subnet)
def describe_subnets(self): def describe_subnets(self):
subnet_ids = self._get_multi_param('SubnetId')
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
subnet_ids = []
idx = 1
key = 'SubnetId.{0}'.format(idx)
while key in self.querystring:
v = self.querystring[key]
subnet_ids.append(v[0])
idx += 1
key = 'SubnetId.{0}'.format(idx)
subnets = self.ec2_backend.get_all_subnets(subnet_ids, filters) subnets = self.ec2_backend.get_all_subnets(subnet_ids, filters)
template = self.response_template(DESCRIBE_SUBNETS_RESPONSE) template = self.response_template(DESCRIBE_SUBNETS_RESPONSE)
return template.render(subnets=subnets) return template.render(subnets=subnets)
def modify_subnet_attribute(self): def modify_subnet_attribute(self):
subnet_id = self.querystring.get('SubnetId')[0] subnet_id = self._get_param('SubnetId')
map_public_ip = self.querystring.get('MapPublicIpOnLaunch.Value')[0] map_public_ip = self._get_param('MapPublicIpOnLaunch.Value')
self.ec2_backend.modify_subnet_attribute(subnet_id, map_public_ip) self.ec2_backend.modify_subnet_attribute(subnet_id, map_public_ip)
return MODIFY_SUBNET_ATTRIBUTE_RESPONSE return MODIFY_SUBNET_ATTRIBUTE_RESPONSE

View File

@ -2,14 +2,13 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.models import validate_resource_ids from moto.ec2.models import validate_resource_ids
from moto.ec2.utils import sequence_from_querystring, tags_from_query_string, filters_from_querystring from moto.ec2.utils import tags_from_query_string, filters_from_querystring
class TagResponse(BaseResponse): class TagResponse(BaseResponse):
def create_tags(self): def create_tags(self):
resource_ids = sequence_from_querystring( resource_ids = self._get_multi_param('ResourceId')
'ResourceId', self.querystring)
validate_resource_ids(resource_ids) validate_resource_ids(resource_ids)
self.ec2_backend.do_resources_exist(resource_ids) self.ec2_backend.do_resources_exist(resource_ids)
tags = tags_from_query_string(self.querystring) tags = tags_from_query_string(self.querystring)
@ -18,8 +17,7 @@ class TagResponse(BaseResponse):
return CREATE_RESPONSE return CREATE_RESPONSE
def delete_tags(self): def delete_tags(self):
resource_ids = sequence_from_querystring( resource_ids = self._get_multi_param('ResourceId')
'ResourceId', self.querystring)
validate_resource_ids(resource_ids) validate_resource_ids(resource_ids)
tags = tags_from_query_string(self.querystring) tags = tags_from_query_string(self.querystring)
if self.is_not_dryrun('DeleteTags'): if self.is_not_dryrun('DeleteTags'):

View File

@ -6,8 +6,8 @@ from moto.ec2.utils import filters_from_querystring
class VirtualPrivateGateways(BaseResponse): class VirtualPrivateGateways(BaseResponse):
def attach_vpn_gateway(self): def attach_vpn_gateway(self):
vpn_gateway_id = self.querystring.get('VpnGatewayId')[0] vpn_gateway_id = self._get_param('VpnGatewayId')
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self._get_param('VpcId')
attachment = self.ec2_backend.attach_vpn_gateway( attachment = self.ec2_backend.attach_vpn_gateway(
vpn_gateway_id, vpn_gateway_id,
vpc_id vpc_id
@ -16,13 +16,13 @@ class VirtualPrivateGateways(BaseResponse):
return template.render(attachment=attachment) return template.render(attachment=attachment)
def create_vpn_gateway(self): def create_vpn_gateway(self):
type = self.querystring.get('Type', None)[0] type = self._get_param('Type')
vpn_gateway = self.ec2_backend.create_vpn_gateway(type) vpn_gateway = self.ec2_backend.create_vpn_gateway(type)
template = self.response_template(CREATE_VPN_GATEWAY_RESPONSE) template = self.response_template(CREATE_VPN_GATEWAY_RESPONSE)
return template.render(vpn_gateway=vpn_gateway) return template.render(vpn_gateway=vpn_gateway)
def delete_vpn_gateway(self): def delete_vpn_gateway(self):
vpn_gateway_id = self.querystring.get('VpnGatewayId')[0] vpn_gateway_id = self._get_param('VpnGatewayId')
vpn_gateway = self.ec2_backend.delete_vpn_gateway(vpn_gateway_id) vpn_gateway = self.ec2_backend.delete_vpn_gateway(vpn_gateway_id)
template = self.response_template(DELETE_VPN_GATEWAY_RESPONSE) template = self.response_template(DELETE_VPN_GATEWAY_RESPONSE)
return template.render(vpn_gateway=vpn_gateway) return template.render(vpn_gateway=vpn_gateway)
@ -34,8 +34,8 @@ class VirtualPrivateGateways(BaseResponse):
return template.render(vpn_gateways=vpn_gateways) return template.render(vpn_gateways=vpn_gateways)
def detach_vpn_gateway(self): def detach_vpn_gateway(self):
vpn_gateway_id = self.querystring.get('VpnGatewayId')[0] vpn_gateway_id = self._get_param('VpnGatewayId')
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self._get_param('VpcId')
attachment = self.ec2_backend.detach_vpn_gateway( attachment = self.ec2_backend.detach_vpn_gateway(
vpn_gateway_id, vpn_gateway_id,
vpc_id vpc_id

View File

@ -5,16 +5,15 @@ from moto.core.responses import BaseResponse
class VPCPeeringConnections(BaseResponse): class VPCPeeringConnections(BaseResponse):
def create_vpc_peering_connection(self): def create_vpc_peering_connection(self):
vpc = self.ec2_backend.get_vpc(self.querystring.get('VpcId')[0]) vpc = self.ec2_backend.get_vpc(self._get_param('VpcId'))
peer_vpc = self.ec2_backend.get_vpc( peer_vpc = self.ec2_backend.get_vpc(self._get_param('PeerVpcId'))
self.querystring.get('PeerVpcId')[0])
vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc) vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc)
template = self.response_template( template = self.response_template(
CREATE_VPC_PEERING_CONNECTION_RESPONSE) CREATE_VPC_PEERING_CONNECTION_RESPONSE)
return template.render(vpc_pcx=vpc_pcx) return template.render(vpc_pcx=vpc_pcx)
def delete_vpc_peering_connection(self): def delete_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0] vpc_pcx_id = self._get_param('VpcPeeringConnectionId')
vpc_pcx = self.ec2_backend.delete_vpc_peering_connection(vpc_pcx_id) vpc_pcx = self.ec2_backend.delete_vpc_peering_connection(vpc_pcx_id)
template = self.response_template( template = self.response_template(
DELETE_VPC_PEERING_CONNECTION_RESPONSE) DELETE_VPC_PEERING_CONNECTION_RESPONSE)
@ -27,14 +26,14 @@ class VPCPeeringConnections(BaseResponse):
return template.render(vpc_pcxs=vpc_pcxs) return template.render(vpc_pcxs=vpc_pcxs)
def accept_vpc_peering_connection(self): def accept_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0] vpc_pcx_id = self._get_param('VpcPeeringConnectionId')
vpc_pcx = self.ec2_backend.accept_vpc_peering_connection(vpc_pcx_id) vpc_pcx = self.ec2_backend.accept_vpc_peering_connection(vpc_pcx_id)
template = self.response_template( template = self.response_template(
ACCEPT_VPC_PEERING_CONNECTION_RESPONSE) ACCEPT_VPC_PEERING_CONNECTION_RESPONSE)
return template.render(vpc_pcx=vpc_pcx) return template.render(vpc_pcx=vpc_pcx)
def reject_vpc_peering_connection(self): def reject_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0] vpc_pcx_id = self._get_param('VpcPeeringConnectionId')
self.ec2_backend.reject_vpc_peering_connection(vpc_pcx_id) self.ec2_backend.reject_vpc_peering_connection(vpc_pcx_id)
template = self.response_template( template = self.response_template(
REJECT_VPC_PEERING_CONNECTION_RESPONSE) REJECT_VPC_PEERING_CONNECTION_RESPONSE)

View File

@ -1,42 +1,41 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring, vpc_ids_from_querystring from moto.ec2.utils import filters_from_querystring
class VPCs(BaseResponse): class VPCs(BaseResponse):
def create_vpc(self): def create_vpc(self):
cidr_block = self.querystring.get('CidrBlock')[0] cidr_block = self._get_param('CidrBlock')
instance_tenancy = self.querystring.get( instance_tenancy = self._get_param('InstanceTenancy', if_none='default')
'InstanceTenancy', ['default'])[0]
vpc = self.ec2_backend.create_vpc(cidr_block, instance_tenancy) vpc = self.ec2_backend.create_vpc(cidr_block, instance_tenancy)
template = self.response_template(CREATE_VPC_RESPONSE) template = self.response_template(CREATE_VPC_RESPONSE)
return template.render(vpc=vpc) return template.render(vpc=vpc)
def delete_vpc(self): def delete_vpc(self):
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self._get_param('VpcId')
vpc = self.ec2_backend.delete_vpc(vpc_id) vpc = self.ec2_backend.delete_vpc(vpc_id)
template = self.response_template(DELETE_VPC_RESPONSE) template = self.response_template(DELETE_VPC_RESPONSE)
return template.render(vpc=vpc) return template.render(vpc=vpc)
def describe_vpcs(self): def describe_vpcs(self):
vpc_ids = vpc_ids_from_querystring(self.querystring) vpc_ids = self._get_multi_param('VpcId')
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
vpcs = self.ec2_backend.get_all_vpcs(vpc_ids=vpc_ids, filters=filters) vpcs = self.ec2_backend.get_all_vpcs(vpc_ids=vpc_ids, filters=filters)
template = self.response_template(DESCRIBE_VPCS_RESPONSE) template = self.response_template(DESCRIBE_VPCS_RESPONSE)
return template.render(vpcs=vpcs) return template.render(vpcs=vpcs)
def describe_vpc_attribute(self): def describe_vpc_attribute(self):
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self._get_param('VpcId')
attribute = self.querystring.get('Attribute')[0] attribute = self._get_param('Attribute')
attr_name = camelcase_to_underscores(attribute) attr_name = camelcase_to_underscores(attribute)
value = self.ec2_backend.describe_vpc_attribute(vpc_id, attr_name) value = self.ec2_backend.describe_vpc_attribute(vpc_id, attr_name)
template = self.response_template(DESCRIBE_VPC_ATTRIBUTE_RESPONSE) template = self.response_template(DESCRIBE_VPC_ATTRIBUTE_RESPONSE)
return template.render(vpc_id=vpc_id, attribute=attribute, value=value) return template.render(vpc_id=vpc_id, attribute=attribute, value=value)
def modify_vpc_attribute(self): def modify_vpc_attribute(self):
vpc_id = self.querystring.get('VpcId')[0] vpc_id = self._get_param('VpcId')
for attribute in ('EnableDnsSupport', 'EnableDnsHostnames'): for attribute in ('EnableDnsSupport', 'EnableDnsHostnames'):
if self.querystring.get('%s.Value' % attribute): if self.querystring.get('%s.Value' % attribute):

View File

@ -1,30 +1,29 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring, sequence_from_querystring from moto.ec2.utils import filters_from_querystring
class VPNConnections(BaseResponse): class VPNConnections(BaseResponse):
def create_vpn_connection(self): def create_vpn_connection(self):
type = self.querystring.get("Type", [None])[0] type = self._get_param('Type')
cgw_id = self.querystring.get("CustomerGatewayId", [None])[0] cgw_id = self._get_param('CustomerGatewayId')
vgw_id = self.querystring.get("VPNGatewayId", [None])[0] vgw_id = self._get_param('VPNGatewayId')
static_routes = self.querystring.get("StaticRoutesOnly", [None])[0] static_routes = self._get_param('StaticRoutesOnly')
vpn_connection = self.ec2_backend.create_vpn_connection( vpn_connection = self.ec2_backend.create_vpn_connection(
type, cgw_id, vgw_id, static_routes_only=static_routes) type, cgw_id, vgw_id, static_routes_only=static_routes)
template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE) template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection) return template.render(vpn_connection=vpn_connection)
def delete_vpn_connection(self): def delete_vpn_connection(self):
vpn_connection_id = self.querystring.get('VpnConnectionId')[0] vpn_connection_id = self._get_param('VpnConnectionId')
vpn_connection = self.ec2_backend.delete_vpn_connection( vpn_connection = self.ec2_backend.delete_vpn_connection(
vpn_connection_id) vpn_connection_id)
template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE) template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection) return template.render(vpn_connection=vpn_connection)
def describe_vpn_connections(self): def describe_vpn_connections(self):
vpn_connection_ids = sequence_from_querystring( vpn_connection_ids = self._get_multi_param('VpnConnectionId')
'VpnConnectionId', self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
vpn_connections = self.ec2_backend.get_all_vpn_connections( vpn_connections = self.ec2_backend.get_all_vpn_connections(
vpn_connection_ids=vpn_connection_ids, filters=filters) vpn_connection_ids=vpn_connection_ids, filters=filters)

View File

@ -51,7 +51,7 @@ def random_ami_id():
def random_instance_id(): def random_instance_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX['instance']) return random_id(prefix=EC2_RESOURCE_TO_PREFIX['instance'], size=17)
def random_reservation_id(): def random_reservation_id():
@ -174,62 +174,6 @@ def split_route_id(route_id):
return values[0], values[1] return values[0], values[1]
def instance_ids_from_querystring(querystring_dict):
instance_ids = []
for key, value in querystring_dict.items():
if 'InstanceId' in key:
instance_ids.append(value[0])
return instance_ids
def image_ids_from_querystring(querystring_dict):
image_ids = []
for key, value in querystring_dict.items():
if 'ImageId' in key:
image_ids.append(value[0])
return image_ids
def executable_users_from_querystring(querystring_dict):
user_ids = []
for key, value in querystring_dict.items():
if 'ExecutableBy' in key:
user_ids.append(value[0])
return user_ids
def route_table_ids_from_querystring(querystring_dict):
route_table_ids = []
for key, value in querystring_dict.items():
if 'RouteTableId' in key:
route_table_ids.append(value[0])
return route_table_ids
def network_acl_ids_from_querystring(querystring_dict):
network_acl_ids = []
for key, value in querystring_dict.items():
if 'NetworkAclId' in key:
network_acl_ids.append(value[0])
return network_acl_ids
def vpc_ids_from_querystring(querystring_dict):
vpc_ids = []
for key, value in querystring_dict.items():
if 'VpcId' in key:
vpc_ids.append(value[0])
return vpc_ids
def sequence_from_querystring(parameter, querystring_dict):
parameter_values = []
for key, value in querystring_dict.items():
if parameter in key:
parameter_values.append(value[0])
return parameter_values
def tags_from_query_string(querystring_dict): def tags_from_query_string(querystring_dict):
prefix = 'Tag' prefix = 'Tag'
suffix = 'Key' suffix = 'Key'
@ -286,11 +230,6 @@ def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration'
return response_values return response_values
def optional_from_querystring(parameter, querystring):
parameter_array = querystring.get(parameter)
return parameter_array[0] if parameter_array else None
def filters_from_querystring(querystring_dict): def filters_from_querystring(querystring_dict):
response_values = {} response_values = {}
for key, value in querystring_dict.items(): for key, value in querystring_dict.items():
@ -319,14 +258,6 @@ def dict_from_querystring(parameter, querystring_dict):
return use_dict return use_dict
def keypair_names_from_querystring(querystring_dict):
keypair_names = []
for key, value in querystring_dict.items():
if 'KeyName' in key:
keypair_names.append(value[0])
return keypair_names
def get_object_value(obj, attr): def get_object_value(obj, attr):
keys = attr.split('.') keys = attr.split('.')
val = obj val = obj
@ -335,6 +266,11 @@ def get_object_value(obj, attr):
val = getattr(val, key) val = getattr(val, key)
elif isinstance(val, dict): elif isinstance(val, dict):
val = val[key] val = val[key]
elif isinstance(val, list):
for item in val:
item_val = get_object_value(item, key)
if item_val:
return item_val
else: else:
return None return None
return val return val
@ -385,14 +321,17 @@ filter_dict_attribute_mapping = {
'state-reason-code': '_state_reason.code', 'state-reason-code': '_state_reason.code',
'source-dest-check': 'source_dest_check', 'source-dest-check': 'source_dest_check',
'vpc-id': 'vpc_id', 'vpc-id': 'vpc_id',
'group-id': 'security_groups', 'group-id': 'security_groups.id',
'instance.group-id': 'security_groups', 'instance.group-id': 'security_groups.id',
'instance.group-name': 'security_groups.name',
'instance-type': 'instance_type', 'instance-type': 'instance_type',
'private-ip-address': 'private_ip', 'private-ip-address': 'private_ip',
'ip-address': 'public_ip', 'ip-address': 'public_ip',
'availability-zone': 'placement', 'availability-zone': 'placement',
'architecture': 'architecture', 'architecture': 'architecture',
'image-id': 'image_id' 'image-id': 'image_id',
'network-interface.private-dns-name': 'private_dns',
'private-dns-name': 'private_dns'
} }

7
moto/ecr/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from __future__ import unicode_literals
from .models import ecr_backends
from ..core.models import base_decorator, deprecated_base_decorator
ecr_backend = ecr_backends['us-east-1']
mock_ecr = base_decorator(ecr_backends)
mock_ecr_deprecated = deprecated_base_decorator(ecr_backends)

22
moto/ecr/exceptions.py Normal file
View File

@ -0,0 +1,22 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
class RepositoryNotFoundException(RESTError):
code = 400
def __init__(self, repository_name, registry_id):
super(RepositoryNotFoundException, self).__init__(
error_type="RepositoryNotFoundException",
message="The repository with name '{0}' does not exist in the registry "
"with id '{1}'".format(repository_name, registry_id))
class ImageNotFoundException(RESTError):
code = 400
def __init__(self, image_id, repository_name, registry_id):
super(ImageNotFoundException, self).__init__(
error_type="ImageNotFoundException",
message="The image with imageId {0} does not exist within the repository with name '{1}' "
"in the registry with id '{2}'".format(image_id, repository_name, registry_id))

251
moto/ecr/models.py Normal file
View File

@ -0,0 +1,251 @@
from __future__ import unicode_literals
# from datetime import datetime
from random import random
from moto.core import BaseBackend, BaseModel
from moto.ec2 import ec2_backends
from copy import copy
import hashlib
from moto.ecr.exceptions import ImageNotFoundException, RepositoryNotFoundException
DEFAULT_REGISTRY_ID = '012345678910'
class BaseObject(BaseModel):
def camelCase(self, key):
words = []
for i, word in enumerate(key.split('_')):
if i > 0:
words.append(word.title())
else:
words.append(word)
return ''.join(words)
def gen_response_object(self):
response_object = copy(self.__dict__)
for key, value in response_object.items():
if '_' in key:
response_object[self.camelCase(key)] = value
del response_object[key]
return response_object
@property
def response_object(self):
return self.gen_response_object()
class Repository(BaseObject):
def __init__(self, repository_name):
self.registry_id = DEFAULT_REGISTRY_ID
self.arn = 'arn:aws:ecr:us-east-1:{0}:repository/{1}'.format(
self.registry_id, repository_name)
self.name = repository_name
# self.created = datetime.utcnow()
self.uri = '{0}.dkr.ecr.us-east-1.amazonaws.com/{1}'.format(
self.registry_id, repository_name)
self.images = []
@property
def physical_resource_id(self):
return self.name
@property
def response_object(self):
response_object = self.gen_response_object()
response_object['registryId'] = self.registry_id
response_object['repositoryArn'] = self.arn
response_object['repositoryName'] = self.name
response_object['repositoryUri'] = self.uri
# response_object['createdAt'] = self.created
del response_object['arn'], response_object['name'], response_object['images']
return response_object
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
ecr_backend = ecr_backends[region_name]
return ecr_backend.create_repository(
# RepositoryName is optional in CloudFormation, thus create a random
# name if necessary
repository_name=properties.get(
'RepositoryName', 'ecrrepository{0}'.format(int(random() * 10 ** 6))),
)
@classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
if original_resource.name != properties['RepositoryName']:
ecr_backend = ecr_backends[region_name]
ecr_backend.delete_cluster(original_resource.arn)
return ecr_backend.create_repository(
# RepositoryName is optional in CloudFormation, thus create a
# random name if necessary
repository_name=properties.get(
'RepositoryName', 'RepositoryName{0}'.format(int(random() * 10 ** 6))),
)
else:
# no-op when nothing changed between old and new resources
return original_resource
class Image(BaseObject):
def __init__(self, tag, manifest, repository, registry_id=DEFAULT_REGISTRY_ID):
self.image_tag = tag
self.image_manifest = manifest
self.image_size_in_bytes = 50 * 1024 * 1024
self.repository = repository
self.registry_id = registry_id
self.image_digest = None
self.image_pushed_at = None
def _create_digest(self):
image_contents = 'docker_image{0}'.format(int(random() * 10 ** 6))
self.image_digest = "sha256:%s" % hashlib.sha256(image_contents.encode('utf-8')).hexdigest()
def get_image_digest(self):
if not self.image_digest:
self._create_digest()
return self.image_digest
@property
def response_object(self):
response_object = self.gen_response_object()
response_object['imageId'] = {}
response_object['imageId']['imageTag'] = self.image_tag
response_object['imageId']['imageDigest'] = self.get_image_digest()
response_object['imageManifest'] = self.image_manifest
response_object['repositoryName'] = self.repository
response_object['registryId'] = self.registry_id
return response_object
@property
def response_list_object(self):
response_object = self.gen_response_object()
response_object['imageTag'] = self.image_tag
response_object['imageDigest'] = "i don't know"
return response_object
@property
def response_describe_object(self):
response_object = self.gen_response_object()
response_object['imageTags'] = [self.image_tag]
response_object['imageDigest'] = self.get_image_digest()
response_object['imageManifest'] = self.image_manifest
response_object['repositoryName'] = self.repository
response_object['registryId'] = self.registry_id
response_object['imageSizeInBytes'] = self.image_size_in_bytes
response_object['imagePushedAt'] = '2017-05-09'
return response_object
class ECRBackend(BaseBackend):
def __init__(self):
self.repositories = {}
def describe_repositories(self, registry_id=None, repository_names=None):
"""
maxResults and nextToken not implemented
"""
if repository_names:
for repository_name in repository_names:
if repository_name not in self.repositories:
raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID)
repositories = []
for repository in self.repositories.values():
# If a registry_id was supplied, ensure this repository matches
if registry_id:
if repository.registry_id != registry_id:
continue
# If a list of repository names was supplied, esure this repository
# is in that list
if repository_names:
if repository.name not in repository_names:
continue
repositories.append(repository.response_object)
return repositories
def create_repository(self, repository_name):
repository = Repository(repository_name)
self.repositories[repository_name] = repository
return repository
def delete_repository(self, repository_name, registry_id=None):
if repository_name in self.repositories:
return self.repositories.pop(repository_name)
else:
raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID)
def list_images(self, repository_name, registry_id=None):
"""
maxResults and filtering not implemented
"""
images = []
for repository in self.repositories.values():
if repository_name:
if repository.name != repository_name:
continue
if registry_id:
if repository.registry_id != registry_id:
continue
for image in repository.images:
images.append(image)
return images
def describe_images(self, repository_name, registry_id=None, image_ids=None):
if repository_name in self.repositories:
repository = self.repositories[repository_name]
else:
raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID)
if image_ids:
response = set()
for image_id in image_ids:
found = False
for image in repository.images:
if (('imageDigest' in image_id and image.get_image_digest() == image_id['imageDigest']) or
('imageTag' in image_id and image.image_tag == image_id['imageTag'])):
found = True
response.add(image)
if not found:
image_id_representation = "{imageDigest:'%s', imageTag:'%s'}" % (
image_id.get('imageDigest', 'null'),
image_id.get('imageTag', 'null'),
)
raise ImageNotFoundException(
image_id=image_id_representation,
repository_name=repository_name,
registry_id=registry_id or DEFAULT_REGISTRY_ID)
else:
response = []
for image in repository.images:
response.append(image)
return response
def put_image(self, repository_name, image_manifest, image_tag):
if repository_name in self.repositories:
repository = self.repositories[repository_name]
else:
raise Exception("{0} is not a repository".format(repository_name))
image = Image(image_tag, image_manifest, repository_name)
repository.images.append(image)
return image
ecr_backends = {}
for region, ec2_backend in ec2_backends.items():
ecr_backends[region] = ECRBackend()

164
moto/ecr/responses.py Normal file
View File

@ -0,0 +1,164 @@
from __future__ import unicode_literals
import json
from base64 import b64encode
from datetime import datetime
import time
from moto.core.responses import BaseResponse
from .models import ecr_backends
class ECRResponse(BaseResponse):
@property
def ecr_backend(self):
return ecr_backends[self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param):
return self.request_params.get(param, None)
def create_repository(self):
repository_name = self._get_param('repositoryName')
if repository_name is None:
repository_name = 'default'
repository = self.ecr_backend.create_repository(repository_name)
return json.dumps({
'repository': repository.response_object
})
def describe_repositories(self):
describe_repositories_name = self._get_param('repositoryNames')
registry_id = self._get_param('registryId')
repositories = self.ecr_backend.describe_repositories(
repository_names=describe_repositories_name, registry_id=registry_id)
return json.dumps({
'repositories': repositories,
'failures': []
})
def delete_repository(self):
repository_str = self._get_param('repositoryName')
registry_id = self._get_param('registryId')
repository = self.ecr_backend.delete_repository(repository_str, registry_id)
return json.dumps({
'repository': repository.response_object
})
def put_image(self):
repository_str = self._get_param('repositoryName')
image_manifest = self._get_param('imageManifest')
image_tag = self._get_param('imageTag')
image = self.ecr_backend.put_image(repository_str, image_manifest, image_tag)
return json.dumps({
'image': image.response_object
})
def list_images(self):
repository_str = self._get_param('repositoryName')
registry_id = self._get_param('registryId')
images = self.ecr_backend.list_images(repository_str, registry_id)
return json.dumps({
'imageIds': [image.response_list_object for image in images],
})
def describe_images(self):
repository_str = self._get_param('repositoryName')
registry_id = self._get_param('registryId')
image_ids = self._get_param('imageIds')
images = self.ecr_backend.describe_images(repository_str, registry_id, image_ids)
return json.dumps({
'imageDetails': [image.response_describe_object for image in images],
})
def batch_check_layer_availability(self):
if self.is_not_dryrun('BatchCheckLayerAvailability'):
raise NotImplementedError(
'ECR.batch_check_layer_availability is not yet implemented')
def batch_delete_image(self):
if self.is_not_dryrun('BatchDeleteImage'):
raise NotImplementedError(
'ECR.batch_delete_image is not yet implemented')
def batch_get_image(self):
if self.is_not_dryrun('BatchGetImage'):
raise NotImplementedError(
'ECR.batch_get_image is not yet implemented')
def can_paginate(self):
if self.is_not_dryrun('CanPaginate'):
raise NotImplementedError(
'ECR.can_paginate is not yet implemented')
def complete_layer_upload(self):
if self.is_not_dryrun('CompleteLayerUpload'):
raise NotImplementedError(
'ECR.complete_layer_upload is not yet implemented')
def delete_repository_policy(self):
if self.is_not_dryrun('DeleteRepositoryPolicy'):
raise NotImplementedError(
'ECR.delete_repository_policy is not yet implemented')
def generate_presigned_url(self):
if self.is_not_dryrun('GeneratePresignedUrl'):
raise NotImplementedError(
'ECR.generate_presigned_url is not yet implemented')
def get_authorization_token(self):
registry_ids = self._get_param('registryIds')
if not registry_ids:
registry_ids = [self.region]
auth_data = []
for registry_id in registry_ids:
password = '{}-auth-token'.format(registry_id)
auth_token = b64encode("AWS:{}".format(password).encode('ascii')).decode()
auth_data.append({
'authorizationToken': auth_token,
'expiresAt': time.mktime(datetime(2015, 1, 1).timetuple()),
'proxyEndpoint': 'https://012345678910.dkr.ecr.{}.amazonaws.com'.format(registry_id)
})
return json.dumps({'authorizationData': auth_data})
def get_download_url_for_layer(self):
if self.is_not_dryrun('GetDownloadUrlForLayer'):
raise NotImplementedError(
'ECR.get_download_url_for_layer is not yet implemented')
def get_paginator(self):
if self.is_not_dryrun('GetPaginator'):
raise NotImplementedError(
'ECR.get_paginator is not yet implemented')
def get_repository_policy(self):
if self.is_not_dryrun('GetRepositoryPolicy'):
raise NotImplementedError(
'ECR.get_repository_policy is not yet implemented')
def get_waiter(self):
if self.is_not_dryrun('GetWaiter'):
raise NotImplementedError(
'ECR.get_waiter is not yet implemented')
def initiate_layer_upload(self):
if self.is_not_dryrun('InitiateLayerUpload'):
raise NotImplementedError(
'ECR.initiate_layer_upload is not yet implemented')
def set_repository_policy(self):
if self.is_not_dryrun('SetRepositoryPolicy'):
raise NotImplementedError(
'ECR.set_repository_policy is not yet implemented')
def upload_layer_part(self):
if self.is_not_dryrun('UploadLayerPart'):
raise NotImplementedError(
'ECR.upload_layer_part is not yet implemented')

10
moto/ecr/urls.py Normal file
View File

@ -0,0 +1,10 @@
from __future__ import unicode_literals
from .responses import ECRResponse
url_bases = [
"https?://ecr.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': ECRResponse.dispatch,
}

View File

@ -114,7 +114,7 @@ class TaskDefinition(BaseObject):
family = properties.get( family = properties.get(
'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6))) 'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6)))
container_definitions = properties['ContainerDefinitions'] container_definitions = properties['ContainerDefinitions']
volumes = properties['Volumes'] volumes = properties.get('Volumes')
ecs_backend = ecs_backends[region_name] ecs_backend = ecs_backends[region_name]
return ecs_backend.register_task_definition( return ecs_backend.register_task_definition(
@ -127,7 +127,7 @@ class TaskDefinition(BaseObject):
family = properties.get( family = properties.get(
'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6))) 'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6)))
container_definitions = properties['ContainerDefinitions'] container_definitions = properties['ContainerDefinitions']
volumes = properties['Volumes'] volumes = properties.get('Volumes')
if (original_resource.family != family or if (original_resource.family != family or
original_resource.container_definitions != container_definitions or original_resource.container_definitions != container_definitions or
original_resource.volumes != volumes): original_resource.volumes != volumes):
@ -289,7 +289,7 @@ class ContainerInstance(BaseObject):
'type': 'STRINGSET'}] 'type': 'STRINGSET'}]
self.container_instance_arn = "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format( self.container_instance_arn = "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format(
str(uuid.uuid1())) str(uuid.uuid1()))
self.pending_task_count = 0 self.pending_tasks_count = 0
self.remaining_resources = [ self.remaining_resources = [
{'doubleValue': 0.0, {'doubleValue': 0.0,
'integerValue': 4096, 'integerValue': 4096,
@ -314,7 +314,7 @@ class ContainerInstance(BaseObject):
'stringSetValue': [], 'stringSetValue': [],
'type': 'STRINGSET'} 'type': 'STRINGSET'}
] ]
self.running_task_count = 0 self.running_tasks_count = 0
self.version_info = { self.version_info = {
'agentVersion': "1.0.0", 'agentVersion': "1.0.0",
'agentHash': '4023248', 'agentHash': '4023248',
@ -737,7 +737,7 @@ class EC2ContainerServiceBackend(BaseBackend):
resource["stringSetValue"].remove(str(port)) resource["stringSetValue"].remove(str(port))
else: else:
resource["stringSetValue"].append(str(port)) resource["stringSetValue"].append(str(port))
container_instance.running_task_count += resource_multiplier * 1 container_instance.running_tasks_count += resource_multiplier * 1
def deregister_container_instance(self, cluster_str, container_instance_str, force): def deregister_container_instance(self, cluster_str, container_instance_str, force):
failures = [] failures = []
@ -748,11 +748,11 @@ class EC2ContainerServiceBackend(BaseBackend):
container_instance = self.container_instances[cluster_name].get(container_instance_id) container_instance = self.container_instances[cluster_name].get(container_instance_id)
if container_instance is None: if container_instance is None:
raise Exception("{0} is not a container id in the cluster") raise Exception("{0} is not a container id in the cluster")
if not force and container_instance.running_task_count > 0: if not force and container_instance.running_tasks_count > 0:
raise Exception("Found running tasks on the instance.") raise Exception("Found running tasks on the instance.")
# Currently assume that people might want to do something based around deregistered instances # Currently assume that people might want to do something based around deregistered instances
# with tasks left running on them - but nothing if no tasks were running already # with tasks left running on them - but nothing if no tasks were running already
elif force and container_instance.running_task_count > 0: elif force and container_instance.running_tasks_count > 0:
if not self.container_instances.get('orphaned'): if not self.container_instances.get('orphaned'):
self.container_instances['orphaned'] = {} self.container_instances['orphaned'] = {}
self.container_instances['orphaned'][container_instance_id] = container_instance self.container_instances['orphaned'][container_instance_id] = container_instance

View File

@ -18,8 +18,8 @@ class EC2ContainerServiceResponse(BaseResponse):
except ValueError: except ValueError:
return {} return {}
def _get_param(self, param): def _get_param(self, param, if_none=None):
return self.request_params.get(param, None) return self.request_params.get(param, if_none)
def create_cluster(self): def create_cluster(self):
cluster_name = self._get_param('clusterName') cluster_name = self._get_param('clusterName')

View File

@ -40,6 +40,15 @@ class BadHealthCheckDefinition(ELBClientError):
"HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL") "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL")
class DuplicateListenerError(ELBClientError):
def __init__(self, name, port):
super(DuplicateListenerError, self).__init__(
"DuplicateListener",
"A listener already exists for {0} with LoadBalancerPort {1}, but with a different InstancePort, Protocol, or SSLCertificateId"
.format(name, port))
class DuplicateLoadBalancerName(ELBClientError): class DuplicateLoadBalancerName(ELBClientError):
def __init__(self, name): def __init__(self, name):
@ -47,3 +56,19 @@ class DuplicateLoadBalancerName(ELBClientError):
"DuplicateLoadBalancerName", "DuplicateLoadBalancerName",
"The specified load balancer name already exists for this account: {0}" "The specified load balancer name already exists for this account: {0}"
.format(name)) .format(name))
class EmptyListenersError(ELBClientError):
def __init__(self):
super(EmptyListenersError, self).__init__(
"ValidationError",
"Listeners cannot be empty")
class InvalidSecurityGroupError(ELBClientError):
def __init__(self):
super(InvalidSecurityGroupError, self).__init__(
"ValidationError",
"One or more of the specified security groups do not exist.")

View File

@ -16,10 +16,13 @@ from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.ec2.models import ec2_backends from moto.ec2.models import ec2_backends
from .exceptions import ( from .exceptions import (
LoadBalancerNotFoundError,
TooManyTagsError,
BadHealthCheckDefinition, BadHealthCheckDefinition,
DuplicateLoadBalancerName, DuplicateLoadBalancerName,
DuplicateListenerError,
EmptyListenersError,
InvalidSecurityGroupError,
LoadBalancerNotFoundError,
TooManyTagsError,
) )
@ -61,7 +64,7 @@ class FakeBackend(BaseModel):
class FakeLoadBalancer(BaseModel): class FakeLoadBalancer(BaseModel):
def __init__(self, name, zones, ports, scheme='internet-facing', vpc_id=None, subnets=None): def __init__(self, name, zones, ports, scheme='internet-facing', vpc_id=None, subnets=None, security_groups=None):
self.name = name self.name = name
self.health_check = None self.health_check = None
self.instance_ids = [] self.instance_ids = []
@ -75,6 +78,7 @@ class FakeLoadBalancer(BaseModel):
self.policies.other_policies = [] self.policies.other_policies = []
self.policies.app_cookie_stickiness_policies = [] self.policies.app_cookie_stickiness_policies = []
self.policies.lb_cookie_stickiness_policies = [] self.policies.lb_cookie_stickiness_policies = []
self.security_groups = security_groups or []
self.subnets = subnets or [] self.subnets = subnets or []
self.vpc_id = vpc_id or 'vpc-56e10e3d' self.vpc_id = vpc_id or 'vpc-56e10e3d'
self.tags = {} self.tags = {}
@ -231,7 +235,7 @@ class ELBBackend(BaseBackend):
self.__dict__ = {} self.__dict__ = {}
self.__init__(region_name) self.__init__(region_name)
def create_load_balancer(self, name, zones, ports, scheme='internet-facing', subnets=None): def create_load_balancer(self, name, zones, ports, scheme='internet-facing', subnets=None, security_groups=None):
vpc_id = None vpc_id = None
ec2_backend = ec2_backends[self.region_name] ec2_backend = ec2_backends[self.region_name]
if subnets: if subnets:
@ -239,8 +243,21 @@ class ELBBackend(BaseBackend):
vpc_id = subnet.vpc_id vpc_id = subnet.vpc_id
if name in self.load_balancers: if name in self.load_balancers:
raise DuplicateLoadBalancerName(name) raise DuplicateLoadBalancerName(name)
if not ports:
raise EmptyListenersError()
if not security_groups:
security_groups = []
for security_group in security_groups:
if ec2_backend.get_security_group_from_id(security_group) is None:
raise InvalidSecurityGroupError()
new_load_balancer = FakeLoadBalancer( new_load_balancer = FakeLoadBalancer(
name=name, zones=zones, ports=ports, scheme=scheme, subnets=subnets, vpc_id=vpc_id) name=name,
zones=zones,
ports=ports,
scheme=scheme,
subnets=subnets,
security_groups=security_groups,
vpc_id=vpc_id)
self.load_balancers[name] = new_load_balancer self.load_balancers[name] = new_load_balancer
return new_load_balancer return new_load_balancer
@ -254,6 +271,12 @@ class ELBBackend(BaseBackend):
ssl_certificate_id = port.get('sslcertificate_id') ssl_certificate_id = port.get('sslcertificate_id')
for listener in balancer.listeners: for listener in balancer.listeners:
if lb_port == listener.load_balancer_port: if lb_port == listener.load_balancer_port:
if protocol != listener.protocol:
raise DuplicateListenerError(name, lb_port)
if instance_port != listener.instance_port:
raise DuplicateListenerError(name, lb_port)
if ssl_certificate_id != listener.ssl_certificate_id:
raise DuplicateListenerError(name, lb_port)
break break
else: else:
balancer.listeners.append(FakeListener( balancer.listeners.append(FakeListener(
@ -292,6 +315,14 @@ class ELBBackend(BaseBackend):
def get_load_balancer(self, load_balancer_name): def get_load_balancer(self, load_balancer_name):
return self.load_balancers.get(load_balancer_name) return self.load_balancers.get(load_balancer_name)
def apply_security_groups_to_load_balancer(self, load_balancer_name, security_group_ids):
load_balancer = self.load_balancers.get(load_balancer_name)
ec2_backend = ec2_backends[self.region_name]
for security_group_id in security_group_ids:
if ec2_backend.get_security_group_from_id(security_group_id) is None:
raise InvalidSecurityGroupError()
load_balancer.security_groups = security_group_ids
def configure_health_check(self, load_balancer_name, timeout, def configure_health_check(self, load_balancer_name, timeout,
healthy_threshold, unhealthy_threshold, interval, healthy_threshold, unhealthy_threshold, interval,
target): target):

View File

@ -27,6 +27,7 @@ class ELBResponse(BaseResponse):
ports = self._get_list_prefix("Listeners.member") ports = self._get_list_prefix("Listeners.member")
scheme = self._get_param('Scheme') scheme = self._get_param('Scheme')
subnets = self._get_multi_param("Subnets.member") subnets = self._get_multi_param("Subnets.member")
security_groups = self._get_multi_param("SecurityGroups.member")
load_balancer = self.elb_backend.create_load_balancer( load_balancer = self.elb_backend.create_load_balancer(
name=load_balancer_name, name=load_balancer_name,
@ -34,6 +35,7 @@ class ELBResponse(BaseResponse):
ports=ports, ports=ports,
scheme=scheme, scheme=scheme,
subnets=subnets, subnets=subnets,
security_groups=security_groups,
) )
self._add_tags(load_balancer) self._add_tags(load_balancer)
template = self.response_template(CREATE_LOAD_BALANCER_TEMPLATE) template = self.response_template(CREATE_LOAD_BALANCER_TEMPLATE)
@ -84,6 +86,13 @@ class ELBResponse(BaseResponse):
template = self.response_template(DELETE_LOAD_BALANCER_TEMPLATE) template = self.response_template(DELETE_LOAD_BALANCER_TEMPLATE)
return template.render() return template.render()
def apply_security_groups_to_load_balancer(self):
load_balancer_name = self._get_param('LoadBalancerName')
security_group_ids = self._get_multi_param("SecurityGroups.member")
self.elb_backend.apply_security_groups_to_load_balancer(load_balancer_name, security_group_ids)
template = self.response_template(APPLY_SECURITY_GROUPS_TEMPLATE)
return template.render(security_group_ids=security_group_ids)
def configure_health_check(self): def configure_health_check(self):
check = self.elb_backend.configure_health_check( check = self.elb_backend.configure_health_check(
load_balancer_name=self._get_param('LoadBalancerName'), load_balancer_name=self._get_param('LoadBalancerName'),
@ -99,8 +108,7 @@ class ELBResponse(BaseResponse):
def register_instances_with_load_balancer(self): def register_instances_with_load_balancer(self):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
instance_ids = [value[0] for key, value in self.querystring.items( instance_ids = [list(param.values())[0] for param in self._get_list_prefix('Instances.member')]
) if "Instances.member" in key]
template = self.response_template(REGISTER_INSTANCES_TEMPLATE) template = self.response_template(REGISTER_INSTANCES_TEMPLATE)
load_balancer = self.elb_backend.register_instances( load_balancer = self.elb_backend.register_instances(
load_balancer_name, instance_ids) load_balancer_name, instance_ids)
@ -119,8 +127,7 @@ class ELBResponse(BaseResponse):
def deregister_instances_from_load_balancer(self): def deregister_instances_from_load_balancer(self):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
instance_ids = [value[0] for key, value in self.querystring.items( instance_ids = [list(param.values())[0] for param in self._get_list_prefix('Instances.member')]
) if "Instances.member" in key]
template = self.response_template(DEREGISTER_INSTANCES_TEMPLATE) template = self.response_template(DEREGISTER_INSTANCES_TEMPLATE)
load_balancer = self.elb_backend.deregister_instances( load_balancer = self.elb_backend.deregister_instances(
load_balancer_name, instance_ids) load_balancer_name, instance_ids)
@ -159,9 +166,8 @@ class ELBResponse(BaseResponse):
if connection_draining: if connection_draining:
attribute = ConnectionDrainingAttribute() attribute = ConnectionDrainingAttribute()
attribute.enabled = connection_draining["enabled"] == "true" attribute.enabled = connection_draining["enabled"] == "true"
attribute.timeout = connection_draining["timeout"] attribute.timeout = connection_draining.get("timeout", 300)
self.elb_backend.set_connection_draining_attribute( self.elb_backend.set_connection_draining_attribute(load_balancer_name, attribute)
load_balancer_name, attribute)
connection_settings = self._get_dict_param( connection_settings = self._get_dict_param(
"LoadBalancerAttributes.ConnectionSettings.") "LoadBalancerAttributes.ConnectionSettings.")
@ -172,7 +178,7 @@ class ELBResponse(BaseResponse):
load_balancer_name, attribute) load_balancer_name, attribute)
template = self.response_template(MODIFY_ATTRIBUTES_TEMPLATE) template = self.response_template(MODIFY_ATTRIBUTES_TEMPLATE)
return template.render(attributes=load_balancer.attributes) return template.render(load_balancer=load_balancer, attributes=load_balancer.attributes)
def create_load_balancer_policy(self): def create_load_balancer_policy(self):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
@ -253,8 +259,7 @@ class ELBResponse(BaseResponse):
def describe_instance_health(self): def describe_instance_health(self):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
instance_ids = [value[0] for key, value in self.querystring.items( instance_ids = [list(param.values())[0] for param in self._get_list_prefix('Instances.member')]
) if "Instances.member" in key]
if len(instance_ids) == 0: if len(instance_ids) == 0:
instance_ids = self.elb_backend.get_load_balancer( instance_ids = self.elb_backend.get_load_balancer(
load_balancer_name).instance_ids load_balancer_name).instance_ids
@ -401,6 +406,9 @@ DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http
{% for load_balancer in load_balancers %} {% for load_balancer in load_balancers %}
<member> <member>
<SecurityGroups> <SecurityGroups>
{% for security_group_id in load_balancer.security_groups %}
<member>{{ security_group_id }}</member>
{% endfor %}
</SecurityGroups> </SecurityGroups>
<LoadBalancerName>{{ load_balancer.name }}</LoadBalancerName> <LoadBalancerName>{{ load_balancer.name }}</LoadBalancerName>
<CreatedTime>{{ load_balancer.created_time }}</CreatedTime> <CreatedTime>{{ load_balancer.created_time }}</CreatedTime>
@ -514,6 +522,19 @@ DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http
</ResponseMetadata> </ResponseMetadata>
</DescribeLoadBalancersResponse>""" </DescribeLoadBalancersResponse>"""
APPLY_SECURITY_GROUPS_TEMPLATE = """<ApplySecurityGroupsToLoadBalancerResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2012-06-01/">
<ApplySecurityGroupsToLoadBalancerResult>
<SecurityGroups>
{% for security_group_id in security_group_ids %}
<member>{{ security_group_id }}</member>
{% endfor %}
</SecurityGroups>
</ApplySecurityGroupsToLoadBalancerResult>
<ResponseMetadata>
<RequestId>f9880f01-7852-629d-a6c3-3ae2-666a409287e6dc0c</RequestId>
</ResponseMetadata>
</ApplySecurityGroupsToLoadBalancerResponse>"""
CONFIGURE_HEALTH_CHECK_TEMPLATE = """<ConfigureHealthCheckResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2012-06-01/"> CONFIGURE_HEALTH_CHECK_TEMPLATE = """<ConfigureHealthCheckResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2012-06-01/">
<ConfigureHealthCheckResult> <ConfigureHealthCheckResult>
<HealthCheck> <HealthCheck>
@ -592,9 +613,11 @@ DESCRIBE_ATTRIBUTES_TEMPLATE = """<DescribeLoadBalancerAttributesResponse xmlns
<Enabled>{{ attributes.cross_zone_load_balancing.enabled }}</Enabled> <Enabled>{{ attributes.cross_zone_load_balancing.enabled }}</Enabled>
</CrossZoneLoadBalancing> </CrossZoneLoadBalancing>
<ConnectionDraining> <ConnectionDraining>
<Enabled>{{ attributes.connection_draining.enabled }}</Enabled>
{% if attributes.connection_draining.enabled %} {% if attributes.connection_draining.enabled %}
<Enabled>true</Enabled>
<Timeout>{{ attributes.connection_draining.timeout }}</Timeout> <Timeout>{{ attributes.connection_draining.timeout }}</Timeout>
{% else %}
<Enabled>false</Enabled>
{% endif %} {% endif %}
</ConnectionDraining> </ConnectionDraining>
</LoadBalancerAttributes> </LoadBalancerAttributes>
@ -607,7 +630,7 @@ DESCRIBE_ATTRIBUTES_TEMPLATE = """<DescribeLoadBalancerAttributesResponse xmlns
MODIFY_ATTRIBUTES_TEMPLATE = """<ModifyLoadBalancerAttributesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2012-06-01/"> MODIFY_ATTRIBUTES_TEMPLATE = """<ModifyLoadBalancerAttributesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2012-06-01/">
<ModifyLoadBalancerAttributesResult> <ModifyLoadBalancerAttributesResult>
<LoadBalancerName>my-loadbalancer</LoadBalancerName> <LoadBalancerName>{{ load_balancer.name }}</LoadBalancerName>
<LoadBalancerAttributes> <LoadBalancerAttributes>
<AccessLog> <AccessLog>
<Enabled>{{ attributes.access_log.enabled }}</Enabled> <Enabled>{{ attributes.access_log.enabled }}</Enabled>
@ -624,9 +647,11 @@ MODIFY_ATTRIBUTES_TEMPLATE = """<ModifyLoadBalancerAttributesResponse xmlns="htt
<Enabled>{{ attributes.cross_zone_load_balancing.enabled }}</Enabled> <Enabled>{{ attributes.cross_zone_load_balancing.enabled }}</Enabled>
</CrossZoneLoadBalancing> </CrossZoneLoadBalancing>
<ConnectionDraining> <ConnectionDraining>
<Enabled>{{ attributes.connection_draining.enabled }}</Enabled>
{% if attributes.connection_draining.enabled %} {% if attributes.connection_draining.enabled %}
<Enabled>true</Enabled>
<Timeout>{{ attributes.connection_draining.timeout }}</Timeout> <Timeout>{{ attributes.connection_draining.timeout }}</Timeout>
{% else %}
<Enabled>false</Enabled>
{% endif %} {% endif %}
</ConnectionDraining> </ConnectionDraining>
</LoadBalancerAttributes> </LoadBalancerAttributes>

View File

@ -1,10 +1,44 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import ELBResponse from six.moves.urllib.parse import parse_qs
from botocore.awsrequest import AWSPreparedRequest
from moto.elb.responses import ELBResponse
from moto.elbv2.responses import ELBV2Response
def api_version_elb_backend(*args, **kwargs):
"""
ELB and ELBV2 (Classic and Application load balancers) use the same
hostname and url space. To differentiate them we must read the
`Version` parameter out of the url-encoded request body. TODO: There
has _got_ to be a better way to do this. Please help us think of
one.
"""
request = args[0]
if hasattr(request, 'values'):
# boto3
version = request.values.get('Version')
elif isinstance(request, AWSPreparedRequest):
# boto in-memory
version = parse_qs(request.body).get('Version')[0]
else:
# boto in server mode
request.parse_request()
version = request.querystring.get('Version')[0]
if '2012-06-01' == version:
return ELBResponse.dispatch(*args, **kwargs)
elif '2015-12-01' == version:
return ELBV2Response.dispatch(*args, **kwargs)
else:
raise Exception("Unknown ELB API version: {}".format(version))
url_bases = [ url_bases = [
"https?://elasticloadbalancing.(.+).amazonaws.com", "https?://elasticloadbalancing.(.+).amazonaws.com",
] ]
url_paths = { url_paths = {
'{0}/$': ELBResponse.dispatch, '{0}/$': api_version_elb_backend,
} }

6
moto/elbv2/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import elbv2_backends
from ..core.models import base_decorator
elb_backend = elbv2_backends['us-east-1']
mock_elbv2 = base_decorator(elbv2_backends)

192
moto/elbv2/exceptions.py Normal file
View File

@ -0,0 +1,192 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
class ELBClientError(RESTError):
code = 400
class DuplicateTagKeysError(ELBClientError):
def __init__(self, cidr):
super(DuplicateTagKeysError, self).__init__(
"DuplicateTagKeys",
"Tag key was specified more than once: {0}"
.format(cidr))
class LoadBalancerNotFoundError(ELBClientError):
def __init__(self):
super(LoadBalancerNotFoundError, self).__init__(
"LoadBalancerNotFound",
"The specified load balancer does not exist.")
class ListenerNotFoundError(ELBClientError):
def __init__(self):
super(ListenerNotFoundError, self).__init__(
"ListenerNotFound",
"The specified listener does not exist.")
class SubnetNotFoundError(ELBClientError):
def __init__(self):
super(SubnetNotFoundError, self).__init__(
"SubnetNotFound",
"The specified subnet does not exist.")
class TargetGroupNotFoundError(ELBClientError):
def __init__(self):
super(TargetGroupNotFoundError, self).__init__(
"TargetGroupNotFound",
"The specified target group does not exist.")
class TooManyTagsError(ELBClientError):
def __init__(self):
super(TooManyTagsError, self).__init__(
"TooManyTagsError",
"The quota for the number of tags that can be assigned to a load balancer has been reached")
class BadHealthCheckDefinition(ELBClientError):
def __init__(self):
super(BadHealthCheckDefinition, self).__init__(
"ValidationError",
"HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL")
class DuplicateListenerError(ELBClientError):
def __init__(self):
super(DuplicateListenerError, self).__init__(
"DuplicateListener",
"A listener with the specified port already exists.")
class DuplicateLoadBalancerName(ELBClientError):
def __init__(self):
super(DuplicateLoadBalancerName, self).__init__(
"DuplicateLoadBalancerName",
"A load balancer with the specified name already exists.")
class DuplicateTargetGroupName(ELBClientError):
def __init__(self):
super(DuplicateTargetGroupName, self).__init__(
"DuplicateTargetGroupName",
"A target group with the specified name already exists.")
class InvalidTargetError(ELBClientError):
def __init__(self):
super(InvalidTargetError, self).__init__(
"InvalidTarget",
"The specified target does not exist or is not in the same VPC as the target group.")
class EmptyListenersError(ELBClientError):
def __init__(self):
super(EmptyListenersError, self).__init__(
"ValidationError",
"Listeners cannot be empty")
class PriorityInUseError(ELBClientError):
def __init__(self):
super(PriorityInUseError, self).__init__(
"PriorityInUse",
"The specified priority is in use.")
class InvalidConditionFieldError(ELBClientError):
def __init__(self, invalid_name):
super(InvalidConditionFieldError, self).__init__(
"ValidationError",
"Condition field '%s' must be one of '[path-pattern, host-header]" % (invalid_name))
class InvalidConditionValueError(ELBClientError):
def __init__(self, msg):
super(InvalidConditionValueError, self).__init__(
"ValidationError", msg)
class InvalidActionTypeError(ELBClientError):
def __init__(self, invalid_name, index):
super(InvalidActionTypeError, self).__init__(
"ValidationError",
"1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward]" % (invalid_name, index)
)
class ActionTargetGroupNotFoundError(ELBClientError):
def __init__(self, arn):
super(ActionTargetGroupNotFoundError, self).__init__(
"TargetGroupNotFound",
"Target group '%s' not found" % arn
)
class InvalidDescribeRulesRequest(ELBClientError):
def __init__(self, msg):
super(InvalidDescribeRulesRequest, self).__init__(
"ValidationError", msg
)
class ResourceInUseError(ELBClientError):
def __init__(self, msg="A specified resource is in use"):
super(ResourceInUseError, self).__init__(
"ResourceInUse", msg)
class RuleNotFoundError(ELBClientError):
def __init__(self):
super(RuleNotFoundError, self).__init__(
"RuleNotFound",
"The specified rule does not exist.")
class DuplicatePriorityError(ELBClientError):
def __init__(self, invalid_value):
super(DuplicatePriorityError, self).__init__(
"ValidationError",
"Priority '%s' was provided multiple times" % invalid_value)
class InvalidTargetGroupNameError(ELBClientError):
def __init__(self, msg):
super(InvalidTargetGroupNameError, self).__init__(
"ValidationError", msg
)
class InvalidModifyRuleArgumentsError(ELBClientError):
def __init__(self):
super(InvalidModifyRuleArgumentsError, self).__init__(
"ValidationError",
"Either conditions or actions must be specified"
)

562
moto/elbv2/models.py Normal file
View File

@ -0,0 +1,562 @@
from __future__ import unicode_literals
import datetime
import re
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.ec2.models import ec2_backends
from .exceptions import (
DuplicateLoadBalancerName,
DuplicateListenerError,
DuplicateTargetGroupName,
InvalidTargetError,
ListenerNotFoundError,
LoadBalancerNotFoundError,
SubnetNotFoundError,
TargetGroupNotFoundError,
TooManyTagsError,
PriorityInUseError,
InvalidConditionFieldError,
InvalidConditionValueError,
InvalidActionTypeError,
ActionTargetGroupNotFoundError,
InvalidDescribeRulesRequest,
ResourceInUseError,
RuleNotFoundError,
DuplicatePriorityError,
InvalidTargetGroupNameError,
InvalidModifyRuleArgumentsError
)
class FakeHealthStatus(BaseModel):
def __init__(self, instance_id, port, health_port, status, reason=None):
self.instance_id = instance_id
self.port = port
self.health_port = health_port
self.status = status
self.reason = reason
class FakeTargetGroup(BaseModel):
def __init__(self,
name,
arn,
vpc_id,
protocol,
port,
healthcheck_protocol,
healthcheck_port,
healthcheck_path,
healthcheck_interval_seconds,
healthcheck_timeout_seconds,
healthy_threshold_count,
unhealthy_threshold_count):
self.name = name
self.arn = arn
self.vpc_id = vpc_id
self.protocol = protocol
self.port = port
self.healthcheck_protocol = healthcheck_protocol
self.healthcheck_port = healthcheck_port
self.healthcheck_path = healthcheck_path
self.healthcheck_interval_seconds = healthcheck_interval_seconds
self.healthcheck_timeout_seconds = healthcheck_timeout_seconds
self.healthy_threshold_count = healthy_threshold_count
self.unhealthy_threshold_count = unhealthy_threshold_count
self.load_balancer_arns = []
self.tags = {}
self.attributes = {
'deregistration_delay.timeout_seconds': 300,
'stickiness.enabled': 'false',
}
self.targets = OrderedDict()
def register(self, targets):
for target in targets:
self.targets[target['id']] = {
'id': target['id'],
'port': target.get('port', self.port),
}
def deregister(self, targets):
for target in targets:
t = self.targets.pop(target['id'], None)
if not t:
raise InvalidTargetError()
def add_tag(self, key, value):
if len(self.tags) >= 10 and key not in self.tags:
raise TooManyTagsError()
self.tags[key] = value
def health_for(self, target):
t = self.targets.get(target['id'])
if t is None:
raise InvalidTargetError()
return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'healthy')
class FakeListener(BaseModel):
def __init__(self, load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions):
self.load_balancer_arn = load_balancer_arn
self.arn = arn
self.protocol = protocol.upper()
self.port = port
self.ssl_policy = ssl_policy
self.certificate = certificate
self.default_actions = default_actions
self._non_default_rules = []
self._default_rule = FakeRule(
listener_arn=self.arn,
conditions=[],
priority='default',
actions=default_actions,
is_default=True
)
@property
def rules(self):
return self._non_default_rules + [self._default_rule]
def remove_rule(self, rule):
self._non_default_rules.remove(rule)
def register(self, rule):
self._non_default_rules.append(rule)
self._non_default_rules = sorted(self._non_default_rules, key=lambda x: x.priority)
class FakeRule(BaseModel):
def __init__(self, listener_arn, conditions, priority, actions, is_default):
self.listener_arn = listener_arn
self.arn = listener_arn.replace(':listener/', ':listener-rule/') + "/%s" % (id(self))
self.conditions = conditions
self.priority = priority # int or 'default'
self.actions = actions
self.is_default = is_default
class FakeBackend(BaseModel):
def __init__(self, instance_port):
self.instance_port = instance_port
self.policy_names = []
def __repr__(self):
return "FakeBackend(inp: %s, policies: %s)" % (self.instance_port, self.policy_names)
class FakeLoadBalancer(BaseModel):
def __init__(self, name, security_groups, subnets, vpc_id, arn, dns_name, scheme='internet-facing'):
self.name = name
self.created_time = datetime.datetime.now()
self.scheme = scheme
self.security_groups = security_groups
self.subnets = subnets or []
self.vpc_id = vpc_id
self.listeners = OrderedDict()
self.tags = {}
self.arn = arn
self.dns_name = dns_name
@property
def physical_resource_id(self):
return self.name
def add_tag(self, key, value):
if len(self.tags) >= 10 and key not in self.tags:
raise TooManyTagsError()
self.tags[key] = value
def list_tags(self):
return self.tags
def remove_tag(self, key):
if key in self.tags:
del self.tags[key]
def delete(self, region):
''' Not exposed as part of the ELB API - used for CloudFormation. '''
elbv2_backends[region].delete_load_balancer(self.arn)
class ELBv2Backend(BaseBackend):
def __init__(self, region_name=None):
self.region_name = region_name
self.target_groups = OrderedDict()
self.load_balancers = OrderedDict()
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_load_balancer(self, name, security_groups, subnet_ids, scheme='internet-facing'):
vpc_id = None
ec2_backend = ec2_backends[self.region_name]
subnets = []
if not subnet_ids:
raise SubnetNotFoundError()
for subnet_id in subnet_ids:
subnet = ec2_backend.get_subnet(subnet_id)
if subnet is None:
raise SubnetNotFoundError()
subnets.append(subnet)
vpc_id = subnets[0].vpc_id
arn = "arn:aws:elasticloadbalancing:%s:1:loadbalancer/%s/50dc6c495c0c9188" % (self.region_name, name)
dns_name = "%s-1.%s.elb.amazonaws.com" % (name, self.region_name)
if arn in self.load_balancers:
raise DuplicateLoadBalancerName()
new_load_balancer = FakeLoadBalancer(
name=name,
security_groups=security_groups,
arn=arn,
scheme=scheme,
subnets=subnets,
vpc_id=vpc_id,
dns_name=dns_name)
self.load_balancers[arn] = new_load_balancer
return new_load_balancer
def create_rule(self, listener_arn, conditions, priority, actions):
listeners = self.describe_listeners(None, [listener_arn])
if not listeners:
raise ListenerNotFoundError()
listener = listeners[0]
# validate conditions
for condition in conditions:
field = condition['field']
if field not in ['path-pattern', 'host-header']:
raise InvalidConditionFieldError(field)
values = condition['values']
if len(values) == 0:
raise InvalidConditionValueError('A condition value must be specified')
if len(values) > 1:
raise InvalidConditionValueError(
"The '%s' field contains too many values; the limit is '1'" % field
)
# TODO: check pattern of value for 'host-header'
# TODO: check pattern of value for 'path-pattern'
# validate Priority
for rule in listener.rules:
if rule.priority == priority:
raise PriorityInUseError()
# validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
for i, action in enumerate(actions):
index = i + 1
action_type = action['type']
if action_type not in ['forward']:
raise InvalidActionTypeError(action_type, index)
action_target_group_arn = action['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
# TODO: check for error 'TooManyRegistrationsForTargetId'
# TODO: check for error 'TooManyRules'
# create rule
rule = FakeRule(listener.arn, conditions, priority, actions, is_default=False)
listener.register(rule)
return [rule]
def create_target_group(self, name, **kwargs):
if len(name) > 32:
raise InvalidTargetGroupNameError(
"Target group name '%s' cannot be longer than '32' characters" % name
)
if not re.match('^[a-zA-Z0-9\-]+$', name):
raise InvalidTargetGroupNameError(
"Target group name '%s' can only contain characters that are alphanumeric characters or hyphens(-)" % name
)
# undocumented validation
if not re.match('(?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$', name):
raise InvalidTargetGroupNameError(
"1 validation error detected: Value '%s' at 'targetGroup.targetGroupArn.targetGroupName' failed to satisfy constraint: Member must satisfy regular expression pattern: (?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$" % name
)
if name.startswith('-') or name.endswith('-'):
raise InvalidTargetGroupNameError(
"Target group name '%s' cannot begin or end with '-'" % name
)
for target_group in self.target_groups.values():
if target_group.name == name:
raise DuplicateTargetGroupName()
arn = "arn:aws:elasticloadbalancing:%s:1:targetgroup/%s/50dc6c495c0c9188" % (self.region_name, name)
target_group = FakeTargetGroup(name, arn, **kwargs)
self.target_groups[target_group.arn] = target_group
return target_group
def create_listener(self, load_balancer_arn, protocol, port, ssl_policy, certificate, default_actions):
balancer = self.load_balancers.get(load_balancer_arn)
if balancer is None:
raise LoadBalancerNotFoundError()
if port in balancer.listeners:
raise DuplicateListenerError()
arn = load_balancer_arn.replace(':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self))
listener = FakeListener(load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions)
balancer.listeners[listener.arn] = listener
return listener
def describe_load_balancers(self, arns, names):
balancers = self.load_balancers.values()
arns = arns or []
names = names or []
if not arns and not names:
return balancers
matched_balancers = []
matched_balancer = None
for arn in arns:
for balancer in balancers:
if balancer.arn == arn:
matched_balancer = balancer
if matched_balancer is None:
raise LoadBalancerNotFoundError()
elif matched_balancer not in matched_balancers:
matched_balancers.append(matched_balancer)
for name in names:
for balancer in balancers:
if balancer.name == name:
matched_balancer = balancer
if matched_balancer is None:
raise LoadBalancerNotFoundError()
elif matched_balancer not in matched_balancers:
matched_balancers.append(matched_balancer)
return matched_balancers
def describe_rules(self, listener_arn, rule_arns):
if listener_arn is None and not rule_arns:
raise InvalidDescribeRulesRequest(
"You must specify either listener rule ARNs or a listener ARN"
)
if listener_arn is not None and rule_arns is not None:
raise InvalidDescribeRulesRequest(
'Listener rule ARNs and a listener ARN cannot be specified at the same time'
)
if listener_arn:
listener = self.describe_listeners(None, [listener_arn])[0]
return listener.rules
# search for rule arns
matched_rules = []
for load_balancer_arn in self.load_balancers:
listeners = self.load_balancers.get(load_balancer_arn).listeners.values()
for listener in listeners:
for rule in listener.rules:
if rule.arn in rule_arns:
matched_rules.append(rule)
return matched_rules
def describe_target_groups(self, load_balancer_arn, target_group_arns, names):
if load_balancer_arn:
if load_balancer_arn not in self.load_balancers:
raise LoadBalancerNotFoundError()
return [tg for tg in self.target_groups.values()
if load_balancer_arn in tg.load_balancer_arns]
if target_group_arns:
try:
return [self.target_groups[arn] for arn in target_group_arns]
except KeyError:
raise TargetGroupNotFoundError()
if names:
matched = []
for name in names:
found = None
for target_group in self.target_groups.values():
if target_group.name == name:
found = target_group
if not found:
raise TargetGroupNotFoundError()
matched.append(found)
return matched
return self.target_groups.values()
def describe_listeners(self, load_balancer_arn, listener_arns):
if load_balancer_arn:
if load_balancer_arn not in self.load_balancers:
raise LoadBalancerNotFoundError()
return self.load_balancers.get(load_balancer_arn).listeners.values()
matched = []
for load_balancer in self.load_balancers.values():
for listener_arn in listener_arns:
listener = load_balancer.listeners.get(listener_arn)
if not listener:
raise ListenerNotFoundError()
matched.append(listener)
return matched
def delete_load_balancer(self, arn):
self.load_balancers.pop(arn, None)
def delete_rule(self, arn):
for load_balancer_arn in self.load_balancers:
listeners = self.load_balancers.get(load_balancer_arn).listeners.values()
for listener in listeners:
for rule in listener.rules:
if rule.arn == arn:
listener.remove_rule(rule)
return
# should raise RuleNotFound Error according to the AWS API doc
# however, boto3 does't raise error even if rule is not found
def delete_target_group(self, target_group_arn):
if target_group_arn not in self.target_groups:
raise TargetGroupNotFoundError()
target_group = self.target_groups[target_group_arn]
if target_group:
if self._any_listener_using(target_group_arn):
raise ResourceInUseError(
"The target group '{}' is currently in use by a listener or a rule".format(
target_group_arn))
del self.target_groups[target_group_arn]
return target_group
def delete_listener(self, listener_arn):
for load_balancer in self.load_balancers.values():
listener = load_balancer.listeners.pop(listener_arn, None)
if listener:
return listener
raise ListenerNotFoundError()
def modify_rule(self, rule_arn, conditions, actions):
# if conditions or actions is empty list, do not update the attributes
if not conditions and not actions:
raise InvalidModifyRuleArgumentsError()
rules = self.describe_rules(listener_arn=None, rule_arns=[rule_arn])
if not rules:
raise RuleNotFoundError()
rule = rules[0]
if conditions:
for condition in conditions:
field = condition['field']
if field not in ['path-pattern', 'host-header']:
raise InvalidConditionFieldError(field)
values = condition['values']
if len(values) == 0:
raise InvalidConditionValueError('A condition value must be specified')
if len(values) > 1:
raise InvalidConditionValueError(
"The '%s' field contains too many values; the limit is '1'" % field
)
# TODO: check pattern of value for 'host-header'
# TODO: check pattern of value for 'path-pattern'
# validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
if actions:
for i, action in enumerate(actions):
index = i + 1
action_type = action['type']
if action_type not in ['forward']:
raise InvalidActionTypeError(action_type, index)
action_target_group_arn = action['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
# TODO: check for error 'TooManyRegistrationsForTargetId'
# TODO: check for error 'TooManyRules'
# modify rule
if conditions:
rule.conditions = conditions
if actions:
rule.actions = actions
return [rule]
def register_targets(self, target_group_arn, instances):
target_group = self.target_groups.get(target_group_arn)
if target_group is None:
raise TargetGroupNotFoundError()
target_group.register(instances)
def deregister_targets(self, target_group_arn, instances):
target_group = self.target_groups.get(target_group_arn)
if target_group is None:
raise TargetGroupNotFoundError()
target_group.deregister(instances)
def describe_target_health(self, target_group_arn, targets):
target_group = self.target_groups.get(target_group_arn)
if target_group is None:
raise TargetGroupNotFoundError()
if not targets:
targets = target_group.targets.values()
return [target_group.health_for(target) for target in targets]
def set_rule_priorities(self, rule_priorities):
# validate
priorities = [rule_priority['priority'] for rule_priority in rule_priorities]
for priority in set(priorities):
if priorities.count(priority) > 1:
raise DuplicatePriorityError(priority)
# validate
for rule_priority in rule_priorities:
given_rule_arn = rule_priority['rule_arn']
priority = rule_priority['priority']
_given_rules = self.describe_rules(listener_arn=None, rule_arns=[given_rule_arn])
if not _given_rules:
raise RuleNotFoundError()
given_rule = _given_rules[0]
listeners = self.describe_listeners(None, [given_rule.listener_arn])
listener = listeners[0]
for rule_in_listener in listener.rules:
if rule_in_listener.priority == priority:
raise PriorityInUseError()
# modify
modified_rules = []
for rule_priority in rule_priorities:
given_rule_arn = rule_priority['rule_arn']
priority = rule_priority['priority']
_given_rules = self.describe_rules(listener_arn=None, rule_arns=[given_rule_arn])
if not _given_rules:
raise RuleNotFoundError()
given_rule = _given_rules[0]
given_rule.priority = priority
modified_rules.append(given_rule)
return modified_rules
def _any_listener_using(self, target_group_arn):
for load_balancer in self.load_balancers.values():
for listener in load_balancer.listeners.values():
for rule in listener.rules:
for action in rule.actions:
if action.get('target_group_arn') == target_group_arn:
return True
return False
elbv2_backends = {}
for region in ec2_backends.keys():
elbv2_backends[region] = ELBv2Backend(region)

960
moto/elbv2/responses.py Normal file
View File

@ -0,0 +1,960 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from .models import elbv2_backends
from .exceptions import DuplicateTagKeysError
from .exceptions import LoadBalancerNotFoundError
from .exceptions import TargetGroupNotFoundError
class ELBV2Response(BaseResponse):
@property
def elbv2_backend(self):
return elbv2_backends[self.region]
def create_load_balancer(self):
load_balancer_name = self._get_param('Name')
subnet_ids = self._get_multi_param("Subnets.member")
security_groups = self._get_multi_param("SecurityGroups.member")
scheme = self._get_param('Scheme')
load_balancer = self.elbv2_backend.create_load_balancer(
name=load_balancer_name,
security_groups=security_groups,
subnet_ids=subnet_ids,
scheme=scheme,
)
self._add_tags(load_balancer)
template = self.response_template(CREATE_LOAD_BALANCER_TEMPLATE)
return template.render(load_balancer=load_balancer)
def create_rule(self):
lister_arn = self._get_param('ListenerArn')
_conditions = self._get_list_prefix('Conditions.member')
conditions = []
for _condition in _conditions:
condition = {}
condition['field'] = _condition['field']
values = sorted(
[e for e in _condition.items() if e[0].startswith('values.member')],
key=lambda x: x[0]
)
condition['values'] = [e[1] for e in values]
conditions.append(condition)
priority = self._get_int_param('Priority')
actions = self._get_list_prefix('Actions.member')
rules = self.elbv2_backend.create_rule(
listener_arn=lister_arn,
conditions=conditions,
priority=priority,
actions=actions
)
template = self.response_template(CREATE_RULE_TEMPLATE)
return template.render(rules=rules)
def create_target_group(self):
name = self._get_param('Name')
vpc_id = self._get_param('VpcId')
protocol = self._get_param('Protocol')
port = self._get_param('Port')
healthcheck_protocol = self._get_param('HealthCheckProtocol', 'HTTP')
healthcheck_port = self._get_param('HealthCheckPort', 'traffic-port')
healthcheck_path = self._get_param('HealthCheckPath', '/')
healthcheck_interval_seconds = self._get_param('HealthCheckIntervalSeconds', '30')
healthcheck_timeout_seconds = self._get_param('HealthCheckTimeoutSeconds', '5')
healthy_threshold_count = self._get_param('HealthyThresholdCount', '5')
unhealthy_threshold_count = self._get_param('UnhealthyThresholdCount', '2')
target_group = self.elbv2_backend.create_target_group(
name,
vpc_id=vpc_id,
protocol=protocol,
port=port,
healthcheck_protocol=healthcheck_protocol,
healthcheck_port=healthcheck_port,
healthcheck_path=healthcheck_path,
healthcheck_interval_seconds=healthcheck_interval_seconds,
healthcheck_timeout_seconds=healthcheck_timeout_seconds,
healthy_threshold_count=healthy_threshold_count,
unhealthy_threshold_count=unhealthy_threshold_count,
)
template = self.response_template(CREATE_TARGET_GROUP_TEMPLATE)
return template.render(target_group=target_group)
def create_listener(self):
load_balancer_arn = self._get_param('LoadBalancerArn')
protocol = self._get_param('Protocol')
port = self._get_param('Port')
ssl_policy = self._get_param('SslPolicy', 'ELBSecurityPolicy-2016-08')
certificates = self._get_list_prefix('Certificates.member')
if certificates:
certificate = certificates[0].get('certificate_arn')
else:
certificate = None
default_actions = self._get_list_prefix('DefaultActions.member')
listener = self.elbv2_backend.create_listener(
load_balancer_arn=load_balancer_arn,
protocol=protocol,
port=port,
ssl_policy=ssl_policy,
certificate=certificate,
default_actions=default_actions)
template = self.response_template(CREATE_LISTENER_TEMPLATE)
return template.render(listener=listener)
def describe_load_balancers(self):
arns = self._get_multi_param("LoadBalancerArns.member")
names = self._get_multi_param("Names.member")
all_load_balancers = list(self.elbv2_backend.describe_load_balancers(arns, names))
marker = self._get_param('Marker')
all_names = [balancer.name for balancer in all_load_balancers]
if marker:
start = all_names.index(marker) + 1
else:
start = 0
page_size = self._get_param('PageSize', 50) # the default is 400, but using 50 to make testing easier
load_balancers_resp = all_load_balancers[start:start + page_size]
next_marker = None
if len(all_load_balancers) > start + page_size:
next_marker = load_balancers_resp[-1].name
template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE)
return template.render(load_balancers=load_balancers_resp, marker=next_marker)
def describe_rules(self):
listener_arn = self._get_param('ListenerArn')
rule_arns = self._get_multi_param('RuleArns.member') if any(k for k in list(self.querystring.keys()) if k.startswith('RuleArns.member')) else None
all_rules = self.elbv2_backend.describe_rules(listener_arn, rule_arns)
all_arns = [rule.arn for rule in all_rules]
page_size = self._get_int_param('PageSize', 50) # set 50 for temporary
marker = self._get_param('Marker')
if marker:
start = all_arns.index(marker) + 1
else:
start = 0
rules_resp = all_rules[start:start + page_size]
next_marker = None
if len(all_rules) > start + page_size:
next_marker = rules_resp[-1].arn
template = self.response_template(DESCRIBE_RULES_TEMPLATE)
return template.render(rules=rules_resp, marker=next_marker)
def describe_target_groups(self):
load_balancer_arn = self._get_param('LoadBalancerArn')
target_group_arns = self._get_multi_param('TargetGroupArns.member')
names = self._get_multi_param('Names.member')
target_groups = self.elbv2_backend.describe_target_groups(load_balancer_arn, target_group_arns, names)
template = self.response_template(DESCRIBE_TARGET_GROUPS_TEMPLATE)
return template.render(target_groups=target_groups)
def describe_target_group_attributes(self):
target_group_arn = self._get_param('TargetGroupArn')
target_group = self.elbv2_backend.target_groups.get(target_group_arn)
if not target_group:
raise TargetGroupNotFoundError()
template = self.response_template(DESCRIBE_TARGET_GROUP_ATTRIBUTES_TEMPLATE)
return template.render(attributes=target_group.attributes)
def describe_listeners(self):
load_balancer_arn = self._get_param('LoadBalancerArn')
listener_arns = self._get_multi_param('ListenerArns.member')
if not load_balancer_arn and not listener_arns:
raise LoadBalancerNotFoundError()
listeners = self.elbv2_backend.describe_listeners(load_balancer_arn, listener_arns)
template = self.response_template(DESCRIBE_LISTENERS_TEMPLATE)
return template.render(listeners=listeners)
def delete_load_balancer(self):
arn = self._get_param('LoadBalancerArn')
self.elbv2_backend.delete_load_balancer(arn)
template = self.response_template(DELETE_LOAD_BALANCER_TEMPLATE)
return template.render()
def delete_rule(self):
arn = self._get_param('RuleArn')
self.elbv2_backend.delete_rule(arn)
template = self.response_template(DELETE_RULE_TEMPLATE)
return template.render()
def delete_target_group(self):
arn = self._get_param('TargetGroupArn')
self.elbv2_backend.delete_target_group(arn)
template = self.response_template(DELETE_TARGET_GROUP_TEMPLATE)
return template.render()
def delete_listener(self):
arn = self._get_param('ListenerArn')
self.elbv2_backend.delete_listener(arn)
template = self.response_template(DELETE_LISTENER_TEMPLATE)
return template.render()
def modify_rule(self):
rule_arn = self._get_param('RuleArn')
_conditions = self._get_list_prefix('Conditions.member')
conditions = []
for _condition in _conditions:
condition = {}
condition['field'] = _condition['field']
values = sorted(
[e for e in _condition.items() if e[0].startswith('values.member')],
key=lambda x: x[0]
)
condition['values'] = [e[1] for e in values]
conditions.append(condition)
actions = self._get_list_prefix('Actions.member')
rules = self.elbv2_backend.modify_rule(
rule_arn=rule_arn,
conditions=conditions,
actions=actions
)
template = self.response_template(MODIFY_RULE_TEMPLATE)
return template.render(rules=rules)
def modify_target_group_attributes(self):
target_group_arn = self._get_param('TargetGroupArn')
target_group = self.elbv2_backend.target_groups.get(target_group_arn)
attributes = {
attr['key']: attr['value']
for attr in self._get_list_prefix('Attributes.member')
}
target_group.attributes.update(attributes)
if not target_group:
raise TargetGroupNotFoundError()
template = self.response_template(MODIFY_TARGET_GROUP_ATTRIBUTES_TEMPLATE)
return template.render(attributes=attributes)
def register_targets(self):
target_group_arn = self._get_param('TargetGroupArn')
targets = self._get_list_prefix('Targets.member')
self.elbv2_backend.register_targets(target_group_arn, targets)
template = self.response_template(REGISTER_TARGETS_TEMPLATE)
return template.render()
def deregister_targets(self):
target_group_arn = self._get_param('TargetGroupArn')
targets = self._get_list_prefix('Targets.member')
self.elbv2_backend.deregister_targets(target_group_arn, targets)
template = self.response_template(DEREGISTER_TARGETS_TEMPLATE)
return template.render()
def describe_target_health(self):
target_group_arn = self._get_param('TargetGroupArn')
targets = self._get_list_prefix('Targets.member')
target_health_descriptions = self.elbv2_backend.describe_target_health(target_group_arn, targets)
template = self.response_template(DESCRIBE_TARGET_HEALTH_TEMPLATE)
return template.render(target_health_descriptions=target_health_descriptions)
def set_rule_priorities(self):
rule_priorities = self._get_list_prefix('RulePriorities.member')
for rule_priority in rule_priorities:
rule_priority['priority'] = int(rule_priority['priority'])
rules = self.elbv2_backend.set_rule_priorities(rule_priorities)
template = self.response_template(SET_RULE_PRIORITIES_TEMPLATE)
return template.render(rules=rules)
def add_tags(self):
resource_arns = self._get_multi_param('ResourceArns.member')
for arn in resource_arns:
if ':targetgroup' in arn:
resource = self.elbv2_backend.target_groups.get(arn)
if not resource:
raise TargetGroupNotFoundError()
elif ':loadbalancer' in arn:
resource = self.elbv2_backend.load_balancers.get(arn)
if not resource:
raise LoadBalancerNotFoundError()
else:
raise LoadBalancerNotFoundError()
self._add_tags(resource)
template = self.response_template(ADD_TAGS_TEMPLATE)
return template.render()
def remove_tags(self):
resource_arns = self._get_multi_param('ResourceArns.member')
tag_keys = self._get_multi_param('TagKeys.member')
for arn in resource_arns:
if ':targetgroup' in arn:
resource = self.elbv2_backend.target_groups.get(arn)
if not resource:
raise TargetGroupNotFoundError()
elif ':loadbalancer' in arn:
resource = self.elbv2_backend.load_balancers.get(arn)
if not resource:
raise LoadBalancerNotFoundError()
else:
raise LoadBalancerNotFoundError()
[resource.remove_tag(key) for key in tag_keys]
template = self.response_template(REMOVE_TAGS_TEMPLATE)
return template.render()
def describe_tags(self):
resource_arns = self._get_multi_param('ResourceArns.member')
resources = []
for arn in resource_arns:
if ':targetgroup' in arn:
resource = self.elbv2_backend.target_groups.get(arn)
if not resource:
raise TargetGroupNotFoundError()
elif ':loadbalancer' in arn:
resource = self.elbv2_backend.load_balancers.get(arn)
if not resource:
raise LoadBalancerNotFoundError()
else:
raise LoadBalancerNotFoundError()
resources.append(resource)
template = self.response_template(DESCRIBE_TAGS_TEMPLATE)
return template.render(resources=resources)
def _add_tags(self, resource):
tag_values = []
tag_keys = []
for t_key, t_val in sorted(self.querystring.items()):
if t_key.startswith('Tags.member.'):
if t_key.split('.')[3] == 'Key':
tag_keys.extend(t_val)
elif t_key.split('.')[3] == 'Value':
tag_values.extend(t_val)
counts = {}
for i in tag_keys:
counts[i] = tag_keys.count(i)
counts = sorted(counts.items(), key=lambda i: i[1], reverse=True)
if counts and counts[0][1] > 1:
# We have dupes...
raise DuplicateTagKeysError(counts[0])
for tag_key, tag_value in zip(tag_keys, tag_values):
resource.add_tag(tag_key, tag_value)
ADD_TAGS_TEMPLATE = """<AddTagsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<AddTagsResult/>
<ResponseMetadata>
<RequestId>360e81f7-1100-11e4-b6ed-0f30EXAMPLE</RequestId>
</ResponseMetadata>
</AddTagsResponse>"""
REMOVE_TAGS_TEMPLATE = """<RemoveTagsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<RemoveTagsResult/>
<ResponseMetadata>
<RequestId>360e81f7-1100-11e4-b6ed-0f30EXAMPLE</RequestId>
</ResponseMetadata>
</RemoveTagsResponse>"""
DESCRIBE_TAGS_TEMPLATE = """<DescribeTagsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeTagsResult>
<TagDescriptions>
{% for resource in resources %}
<member>
<ResourceArn>{{ resource.arn }}</ResourceArn>
<Tags>
{% for key, value in resource.tags.items() %}
<member>
<Value>{{ value }}</Value>
<Key>{{ key }}</Key>
</member>
{% endfor %}
</Tags>
</member>
{% endfor %}
</TagDescriptions>
</DescribeTagsResult>
<ResponseMetadata>
<RequestId>360e81f7-1100-11e4-b6ed-0f30EXAMPLE</RequestId>
</ResponseMetadata>
</DescribeTagsResponse>"""
CREATE_LOAD_BALANCER_TEMPLATE = """<CreateLoadBalancerResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<CreateLoadBalancerResult>
<LoadBalancers>
<member>
<LoadBalancerArn>{{ load_balancer.arn }}</LoadBalancerArn>
<Scheme>{{ load_balancer.scheme }}</Scheme>
<LoadBalancerName>{{ load_balancer.name }}</LoadBalancerName>
<VpcId>{{ load_balancer.vpc_id }}</VpcId>
<CanonicalHostedZoneId>Z2P70J7EXAMPLE</CanonicalHostedZoneId>
<CreatedTime>{{ load_balancer.created_time }}</CreatedTime>
<AvailabilityZones>
{% for subnet in load_balancer.subnets %}
<member>
<SubnetId>{{ subnet.id }}</SubnetId>
<ZoneName>{{ subnet.availability_zone }}</ZoneName>
</member>
{% endfor %}
</AvailabilityZones>
<SecurityGroups>
{% for security_group in load_balancer.security_groups %}
<member>{{ security_group }}</member>
{% endfor %}
</SecurityGroups>
<DNSName>{{ load_balancer.dns_name }}</DNSName>
<State>
<Code>provisioning</Code>
</State>
<Type>application</Type>
</member>
</LoadBalancers>
</CreateLoadBalancerResult>
<ResponseMetadata>
<RequestId>32d531b2-f2d0-11e5-9192-3fff33344cfa</RequestId>
</ResponseMetadata>
</CreateLoadBalancerResponse>"""
CREATE_RULE_TEMPLATE = """<CreateRuleResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<CreateRuleResult>
<Rules>
{% for rule in rules %}
<member>
<IsDefault>{{ "true" if rule.is_default else "false" }}</IsDefault>
<Conditions>
{% for condition in rule.conditions %}
<member>
<Field>{{ condition["field"] }}</Field>
<Values>
{% for value in condition["values"] %}
<member>{{ value }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</Conditions>
<Priority>{{ rule.priority }}</Priority>
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member>
{% endfor %}
</Actions>
<RuleArn>{{ rule.arn }}</RuleArn>
</member>
{% endfor %}
</Rules>
</CreateRuleResult>
<ResponseMetadata>
<RequestId>c5478c83-f397-11e5-bb98-57195a6eb84a</RequestId>
</ResponseMetadata>
</CreateRuleResponse>"""
CREATE_TARGET_GROUP_TEMPLATE = """<CreateTargetGroupResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<CreateTargetGroupResult>
<TargetGroups>
<member>
<TargetGroupArn>{{ target_group.arn }}</TargetGroupArn>
<TargetGroupName>{{ target_group.name }}</TargetGroupName>
<Protocol>{{ target_group.protocol }}</Protocol>
<Port>{{ target_group.port }}</Port>
<VpcId>{{ target_group.vpc_id }}</VpcId>
<HealthCheckProtocol>{{ target_group.health_check_protocol }}</HealthCheckProtocol>
<HealthCheckPort>{{ target_group.healthcheck_port }}</HealthCheckPort>
<HealthCheckPath>{{ target_group.healthcheck_path }}</HealthCheckPath>
<HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds>
<HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds>
<HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount>
<UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount>
<Matcher>
<HttpCode>200</HttpCode>
</Matcher>
</member>
</TargetGroups>
</CreateTargetGroupResult>
<ResponseMetadata>
<RequestId>b83fe90e-f2d5-11e5-b95d-3b2c1831fc26</RequestId>
</ResponseMetadata>
</CreateTargetGroupResponse>"""
CREATE_LISTENER_TEMPLATE = """<CreateListenerResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<CreateListenerResult>
<Listeners>
<member>
<LoadBalancerArn>{{ listener.load_balancer_arn }}</LoadBalancerArn>
<Protocol>{{ listener.protocol }}</Protocol>
{% if listener.certificate %}
<Certificates>
<member>
<CertificateArn>{{ listener.certificate }}</CertificateArn>
</member>
</Certificates>
{% endif %}
<Port>{{ listener.port }}</Port>
<SslPolicy>{{ listener.ssl_policy }}</SslPolicy>
<ListenerArn>{{ listener.arn }}</ListenerArn>
<DefaultActions>
{% for action in listener.default_actions %}
<member>
<Type>{{ action.type }}</Type>
<TargetGroupArn>{{ action.target_group_arn }}</TargetGroupArn>
</member>
{% endfor %}
</DefaultActions>
</member>
</Listeners>
</CreateListenerResult>
<ResponseMetadata>
<RequestId>97f1bb38-f390-11e5-b95d-3b2c1831fc26</RequestId>
</ResponseMetadata>
</CreateListenerResponse>"""
DELETE_LOAD_BALANCER_TEMPLATE = """<DeleteLoadBalancerResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DeleteLoadBalancerResult/>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
</ResponseMetadata>
</DeleteLoadBalancerResponse>"""
DELETE_RULE_TEMPLATE = """<DeleteRuleResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DeleteRuleResult/>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
</ResponseMetadata>
</DeleteRuleResponse>"""
DELETE_TARGET_GROUP_TEMPLATE = """<DeleteTargetGroupResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DeleteTargetGroupResult/>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
</ResponseMetadata>
</DeleteTargetGroupResponse>"""
DELETE_LISTENER_TEMPLATE = """<DeleteListenerResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DeleteListenerResult/>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
</ResponseMetadata>
</DeleteListenerResponse>"""
DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeLoadBalancersResult>
<LoadBalancers>
{% for load_balancer in load_balancers %}
<member>
<LoadBalancerArn>{{ load_balancer.arn }}</LoadBalancerArn>
<Scheme>{{ load_balancer.scheme }}</Scheme>
<LoadBalancerName>{{ load_balancer.name }}</LoadBalancerName>
<VpcId>{{ load_balancer.vpc_id }}</VpcId>
<CanonicalHostedZoneId>Z2P70J7EXAMPLE</CanonicalHostedZoneId>
<CreatedTime>{{ load_balancer.created_time }}</CreatedTime>
<AvailabilityZones>
{% for subnet in load_balancer.subnets %}
<member>
<SubnetId>{{ subnet.id }}</SubnetId>
<ZoneName>{{ subnet.availability_zone }}</ZoneName>
</member>
{% endfor %}
</AvailabilityZones>
<SecurityGroups>
{% for security_group in load_balancer.security_groups %}
<member>{{ security_group }}</member>
{% endfor %}
</SecurityGroups>
<DNSName>{{ load_balancer.dns_name }}</DNSName>
<State>
<Code>provisioning</Code>
</State>
<Type>application</Type>
</member>
{% endfor %}
</LoadBalancers>
{% if marker %}
<NextMarker>{{ marker }}</NextMarker>
{% endif %}
</DescribeLoadBalancersResult>
<ResponseMetadata>
<RequestId>f9880f01-7852-629d-a6c3-3ae2-666a409287e6dc0c</RequestId>
</ResponseMetadata>
</DescribeLoadBalancersResponse>"""
DESCRIBE_RULES_TEMPLATE = """<DescribeRulesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeRulesResult>
<Rules>
{% for rule in rules %}
<member>
<IsDefault>{{ "true" if rule.is_default else "false" }}</IsDefault>
<Conditions>
{% for condition in rule.conditions %}
<member>
<Field>{{ condition["field"] }}</Field>
<Values>
{% for value in condition["values"] %}
<member>{{ value }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</Conditions>
<Priority>{{ rule.priority }}</Priority>
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member>
{% endfor %}
</Actions>
<RuleArn>{{ rule.arn }}</RuleArn>
</member>
{% endfor %}
</Rules>
{% if marker %}
<NextMarker>{{ marker }}</NextMarker>
{% endif %}
</DescribeRulesResult>
<ResponseMetadata>
<RequestId>74926cf3-f3a3-11e5-b543-9f2c3fbb9bee</RequestId>
</ResponseMetadata>
</DescribeRulesResponse>"""
DESCRIBE_TARGET_GROUPS_TEMPLATE = """<DescribeTargetGroupsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeTargetGroupsResult>
<TargetGroups>
{% for target_group in target_groups %}
<member>
<TargetGroupArn>{{ target_group.arn }}</TargetGroupArn>
<TargetGroupName>{{ target_group.name }}</TargetGroupName>
<Protocol>{{ target_group.protocol }}</Protocol>
<Port>{{ target_group.port }}</Port>
<VpcId>{{ target_group.vpc_id }}</VpcId>
<HealthCheckProtocol>{{ target_group.health_check_protocol }}</HealthCheckProtocol>
<HealthCheckPort>{{ target_group.healthcheck_port }}</HealthCheckPort>
<HealthCheckPath>{{ target_group.healthcheck_path }}</HealthCheckPath>
<HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds>
<HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds>
<HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount>
<UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount>
<Matcher>
<HttpCode>200</HttpCode>
</Matcher>
<LoadBalancerArns>
{% for load_balancer_arn in target_group.load_balancer_arns %}
<member>{{ load_balancer_arn }}</member>
{% endfor %}
</LoadBalancerArns>
</member>
{% endfor %}
</TargetGroups>
</DescribeTargetGroupsResult>
<ResponseMetadata>
<RequestId>70092c0e-f3a9-11e5-ae48-cff02092876b</RequestId>
</ResponseMetadata>
</DescribeTargetGroupsResponse>"""
DESCRIBE_TARGET_GROUP_ATTRIBUTES_TEMPLATE = """<DescribeTargetGroupAttributesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeTargetGroupAttributesResult>
<Attributes>
{% for key, value in attributes.items() %}
<member>
<Key>{{ key }}</Key>
<Value>{{ value }}</Value>
</member>
{% endfor %}
</Attributes>
</DescribeTargetGroupAttributesResult>
<ResponseMetadata>
<RequestId>70092c0e-f3a9-11e5-ae48-cff02092876b</RequestId>
</ResponseMetadata>
</DescribeTargetGroupAttributesResponse>"""
DESCRIBE_LISTENERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeListenersResult>
<Listeners>
{% for listener in listeners %}
<member>
<LoadBalancerArn>{{ listener.load_balancer_arn }}</LoadBalancerArn>
<Protocol>{{ listener.protocol }}</Protocol>
{% if listener.certificate %}
<Certificates>
<member>
<CertificateArn>{{ listener.certificate }}</CertificateArn>
</member>
</Certificates>
{% endif %}
<Port>{{ listener.port }}</Port>
<SslPolicy>{{ listener.ssl_policy }}</SslPolicy>
<ListenerArn>{{ listener.arn }}</ListenerArn>
<DefaultActions>
{% for action in listener.default_actions %}
<member>
<Type>{{ action.type }}</Type>
<TargetGroupArn>{{ action.target_group_arn }}</TargetGroupArn>
</member>
{% endfor %}
</DefaultActions>
</member>
{% endfor %}
</Listeners>
</DescribeListenersResult>
<ResponseMetadata>
<RequestId>65a3a7ea-f39c-11e5-b543-9f2c3fbb9bee</RequestId>
</ResponseMetadata>
</DescribeLoadBalancersResponse>"""
CONFIGURE_HEALTH_CHECK_TEMPLATE = """<ConfigureHealthCheckResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<ConfigureHealthCheckResult>
<HealthCheck>
<Interval>{{ check.interval }}</Interval>
<Target>{{ check.target }}</Target>
<HealthyThreshold>{{ check.healthy_threshold }}</HealthyThreshold>
<Timeout>{{ check.timeout }}</Timeout>
<UnhealthyThreshold>{{ check.unhealthy_threshold }}</UnhealthyThreshold>
</HealthCheck>
</ConfigureHealthCheckResult>
<ResponseMetadata>
<RequestId>f9880f01-7852-629d-a6c3-3ae2-666a409287e6dc0c</RequestId>
</ResponseMetadata>
</ConfigureHealthCheckResponse>"""
MODIFY_RULE_TEMPLATE = """<ModifyRuleResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<ModifyRuleResult>
<Rules>
{% for rule in rules %}
<member>
<IsDefault>{{ "true" if rule.is_default else "false" }}</IsDefault>
<Conditions>
{% for condition in rule.conditions %}
<member>
<Field>{{ condition["field"] }}</Field>
<Values>
{% for value in condition["values"] %}
<member>{{ value }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</Conditions>
<Priority>{{ rule.priority }}</Priority>
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member>
{% endfor %}
</Actions>
<RuleArn>{{ rule.arn }}</RuleArn>
</member>
{% endfor %}
</Rules>
</ModifyRuleResult>
<ResponseMetadata>
<RequestId>c5478c83-f397-11e5-bb98-57195a6eb84a</RequestId>
</ResponseMetadata>
</ModifyRuleResponse>"""
MODIFY_TARGET_GROUP_ATTRIBUTES_TEMPLATE = """<ModifyTargetGroupAttributesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<ModifyTargetGroupAttributesResult>
<Attributes>
{% for key, value in attributes.items() %}
<member>
<Key>{{ key }}</Key>
<Value>{{ value }}</Value>
</member>
{% endfor %}
</Attributes>
</ModifyTargetGroupAttributesResult>
<ResponseMetadata>
<RequestId>70092c0e-f3a9-11e5-ae48-cff02092876b</RequestId>
</ResponseMetadata>
</ModifyTargetGroupAttributesResponse>"""
REGISTER_TARGETS_TEMPLATE = """<RegisterTargetsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<RegisterTargetsResult>
</RegisterTargetsResult>
<ResponseMetadata>
<RequestId>f9880f01-7852-629d-a6c3-3ae2-666a409287e6dc0c</RequestId>
</ResponseMetadata>
</RegisterTargetsResponse>"""
DEREGISTER_TARGETS_TEMPLATE = """<DeregisterTargetsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DeregisterTargetsResult>
</DeregisterTargetsResult>
<ResponseMetadata>
<RequestId>f9880f01-7852-629d-a6c3-3ae2-666a409287e6dc0c</RequestId>
</ResponseMetadata>
</DeregisterTargetsResponse>"""
SET_LOAD_BALANCER_SSL_CERTIFICATE = """<SetLoadBalancerListenerSSLCertificateResponse xmlns="http://elasticloadbalan cing.amazonaws.com/doc/2015-12-01/">
<SetLoadBalancerListenerSSLCertificateResult/>
<ResponseMetadata>
<RequestId>83c88b9d-12b7-11e3-8b82-87b12EXAMPLE</RequestId>
</ResponseMetadata>
</SetLoadBalancerListenerSSLCertificateResponse>"""
DELETE_LOAD_BALANCER_LISTENERS = """<DeleteLoadBalancerListenersResponse xmlns="http://elasticloadbalan cing.amazonaws.com/doc/2015-12-01/">
<DeleteLoadBalancerListenersResult/>
<ResponseMetadata>
<RequestId>83c88b9d-12b7-11e3-8b82-87b12EXAMPLE</RequestId>
</ResponseMetadata>
</DeleteLoadBalancerListenersResponse>"""
DESCRIBE_ATTRIBUTES_TEMPLATE = """<DescribeLoadBalancerAttributesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeLoadBalancerAttributesResult>
<LoadBalancerAttributes>
<AccessLog>
<Enabled>{{ attributes.access_log.enabled }}</Enabled>
{% if attributes.access_log.enabled %}
<S3BucketName>{{ attributes.access_log.s3_bucket_name }}</S3BucketName>
<S3BucketPrefix>{{ attributes.access_log.s3_bucket_prefix }}</S3BucketPrefix>
<EmitInterval>{{ attributes.access_log.emit_interval }}</EmitInterval>
{% endif %}
</AccessLog>
<ConnectionSettings>
<IdleTimeout>{{ attributes.connecting_settings.idle_timeout }}</IdleTimeout>
</ConnectionSettings>
<CrossZoneLoadBalancing>
<Enabled>{{ attributes.cross_zone_load_balancing.enabled }}</Enabled>
</CrossZoneLoadBalancing>
<ConnectionDraining>
{% if attributes.connection_draining.enabled %}
<Enabled>true</Enabled>
<Timeout>{{ attributes.connection_draining.timeout }}</Timeout>
{% else %}
<Enabled>false</Enabled>
{% endif %}
</ConnectionDraining>
</LoadBalancerAttributes>
</DescribeLoadBalancerAttributesResult>
<ResponseMetadata>
<RequestId>83c88b9d-12b7-11e3-8b82-87b12EXAMPLE</RequestId>
</ResponseMetadata>
</DescribeLoadBalancerAttributesResponse>
"""
MODIFY_ATTRIBUTES_TEMPLATE = """<ModifyLoadBalancerAttributesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<ModifyLoadBalancerAttributesResult>
<LoadBalancerName>{{ load_balancer.name }}</LoadBalancerName>
<LoadBalancerAttributes>
<AccessLog>
<Enabled>{{ attributes.access_log.enabled }}</Enabled>
{% if attributes.access_log.enabled %}
<S3BucketName>{{ attributes.access_log.s3_bucket_name }}</S3BucketName>
<S3BucketPrefix>{{ attributes.access_log.s3_bucket_prefix }}</S3BucketPrefix>
<EmitInterval>{{ attributes.access_log.emit_interval }}</EmitInterval>
{% endif %}
</AccessLog>
<ConnectionSettings>
<IdleTimeout>{{ attributes.connecting_settings.idle_timeout }}</IdleTimeout>
</ConnectionSettings>
<CrossZoneLoadBalancing>
<Enabled>{{ attributes.cross_zone_load_balancing.enabled }}</Enabled>
</CrossZoneLoadBalancing>
<ConnectionDraining>
{% if attributes.connection_draining.enabled %}
<Enabled>true</Enabled>
<Timeout>{{ attributes.connection_draining.timeout }}</Timeout>
{% else %}
<Enabled>false</Enabled>
{% endif %}
</ConnectionDraining>
</LoadBalancerAttributes>
</ModifyLoadBalancerAttributesResult>
<ResponseMetadata>
<RequestId>83c88b9d-12b7-11e3-8b82-87b12EXAMPLE</RequestId>
</ResponseMetadata>
</ModifyLoadBalancerAttributesResponse>
"""
CREATE_LOAD_BALANCER_POLICY_TEMPLATE = """<CreateLoadBalancerPolicyResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<CreateLoadBalancerPolicyResult/>
<ResponseMetadata>
<RequestId>83c88b9d-12b7-11e3-8b82-87b12EXAMPLE</RequestId>
</ResponseMetadata>
</CreateLoadBalancerPolicyResponse>
"""
SET_LOAD_BALANCER_POLICIES_OF_LISTENER_TEMPLATE = """<SetLoadBalancerPoliciesOfListenerResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<SetLoadBalancerPoliciesOfListenerResult/>
<ResponseMetadata>
<RequestId>07b1ecbc-1100-11e3-acaf-dd7edEXAMPLE</RequestId>
</ResponseMetadata>
</SetLoadBalancerPoliciesOfListenerResponse>
"""
SET_LOAD_BALANCER_POLICIES_FOR_BACKEND_SERVER_TEMPLATE = """<SetLoadBalancerPoliciesForBackendServerResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<SetLoadBalancerPoliciesForBackendServerResult/>
<ResponseMetadata>
<RequestId>0eb9b381-dde0-11e2-8d78-6ddbaEXAMPLE</RequestId>
</ResponseMetadata>
</SetLoadBalancerPoliciesForBackendServerResponse>
"""
DESCRIBE_TARGET_HEALTH_TEMPLATE = """<DescribeTargetHealthResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeTargetHealthResult>
<TargetHealthDescriptions>
{% for target_health in target_health_descriptions %}
<member>
<HealthCheckPort>{{ target_health.health_port }}</HealthCheckPort>
<TargetHealth>
<State>{{ target_health.status }}</State>
</TargetHealth>
<Target>
<Port>{{ target_health.port }}</Port>
<Id>{{ target_health.instance_id }}</Id>
</Target>
</member>
{% endfor %}
</TargetHealthDescriptions>
</DescribeTargetHealthResult>
<ResponseMetadata>
<RequestId>c534f810-f389-11e5-9192-3fff33344cfa</RequestId>
</ResponseMetadata>
</DescribeTargetHealthResponse>"""
SET_RULE_PRIORITIES_TEMPLATE = """<SetRulePrioritiesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<SetRulePrioritiesResult>
<Rules>
{% for rule in rules %}
<member>
<IsDefault>{{ "true" if rule.is_default else "false" }}</IsDefault>
<Conditions>
{% for condition in rule.conditions %}
<member>
<Field>{{ condition["field"] }}</Field>
<Values>
{% for value in condition["values"] %}
<member>{{ value }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</Conditions>
<Priority>{{ rule.priority }}</Priority>
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member>
{% endfor %}
</Actions>
<RuleArn>{{ rule.arn }}</RuleArn>
</member>
{% endfor %}
</Rules>
</SetRulePrioritiesResult>
<ResponseMetadata>
<RequestId>4d7a8036-f3a7-11e5-9c02-8fd20490d5a6</RequestId>
</ResponseMetadata>
</SetRulePrioritiesResponse>"""

10
moto/elbv2/urls.py Normal file
View File

@ -0,0 +1,10 @@
from __future__ import unicode_literals
from ..elb.urls import api_version_elb_backend
url_bases = [
"https?://elasticloadbalancing.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': api_version_elb_backend,
}

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import base64 import base64
from datetime import datetime from datetime import datetime
import json
import pytz import pytz
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.core.utils import iso_8601_datetime_without_milliseconds
from .aws_managed_policies import aws_managed_policies_data
from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException
from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id
@ -72,14 +74,32 @@ class ManagedPolicy(Policy):
is_attachable = True is_attachable = True
def attach_to_role(self, role): def attach_to(self, obj):
self.attachment_count += 1 self.attachment_count += 1
role.managed_policies[self.name] = self obj.managed_policies[self.name] = self
def detach_from(self, obj):
self.attachment_count -= 1
del obj.managed_policies[self.name]
class AWSManagedPolicy(ManagedPolicy): class AWSManagedPolicy(ManagedPolicy):
"""AWS-managed policy.""" """AWS-managed policy."""
@classmethod
def from_data(cls, name, data):
return cls(name,
default_version_id=data.get('DefaultVersionId'),
path=data.get('Path'),
document=data.get('Document'))
# AWS defines some of its own managed policies and we periodically
# import them via `make aws_managed_policies`
aws_managed_policies = [
AWSManagedPolicy.from_data(name, d) for name, d
in json.loads(aws_managed_policies_data).items()]
class InlinePolicy(Policy): class InlinePolicy(Policy):
"""TODO: is this needed?""" """TODO: is this needed?"""
@ -120,6 +140,13 @@ class Role(BaseModel):
def put_policy(self, policy_name, policy_json): def put_policy(self, policy_name, policy_json):
self.policies[policy_name] = policy_json self.policies[policy_name] = policy_json
def delete_policy(self, policy_name):
try:
del self.policies[policy_name]
except KeyError:
raise IAMNotFoundException(
"The role policy with name {0} cannot be found.".format(policy_name))
@property @property
def physical_resource_id(self): def physical_resource_id(self):
return self.id return self.id
@ -214,6 +241,7 @@ class Group(BaseModel):
) )
self.users = [] self.users = []
self.managed_policies = {}
self.policies = {} self.policies = {}
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
@ -254,8 +282,10 @@ class User(BaseModel):
self.created = datetime.utcnow() self.created = datetime.utcnow()
self.mfa_devices = {} self.mfa_devices = {}
self.policies = {} self.policies = {}
self.managed_policies = {}
self.access_keys = [] self.access_keys = []
self.password = None self.password = None
self.password_reset_required = False
@property @property
def arn(self): def arn(self):
@ -367,115 +397,6 @@ class User(BaseModel):
) )
# predefine AWS managed policies
aws_managed_policies = [
AWSManagedPolicy(
'AmazonElasticMapReduceRole',
default_version_id='v6',
path='/service-role/',
document={
"Version": "2012-10-17",
"Statement": [{
"Effect": "Allow",
"Resource": "*",
"Action": [
"ec2:AuthorizeSecurityGroupEgress",
"ec2:AuthorizeSecurityGroupIngress",
"ec2:CancelSpotInstanceRequests",
"ec2:CreateNetworkInterface",
"ec2:CreateSecurityGroup",
"ec2:CreateTags",
"ec2:DeleteNetworkInterface",
"ec2:DeleteSecurityGroup",
"ec2:DeleteTags",
"ec2:DescribeAvailabilityZones",
"ec2:DescribeAccountAttributes",
"ec2:DescribeDhcpOptions",
"ec2:DescribeInstanceStatus",
"ec2:DescribeInstances",
"ec2:DescribeKeyPairs",
"ec2:DescribeNetworkAcls",
"ec2:DescribeNetworkInterfaces",
"ec2:DescribePrefixLists",
"ec2:DescribeRouteTables",
"ec2:DescribeSecurityGroups",
"ec2:DescribeSpotInstanceRequests",
"ec2:DescribeSpotPriceHistory",
"ec2:DescribeSubnets",
"ec2:DescribeVpcAttribute",
"ec2:DescribeVpcEndpoints",
"ec2:DescribeVpcEndpointServices",
"ec2:DescribeVpcs",
"ec2:DetachNetworkInterface",
"ec2:ModifyImageAttribute",
"ec2:ModifyInstanceAttribute",
"ec2:RequestSpotInstances",
"ec2:RevokeSecurityGroupEgress",
"ec2:RunInstances",
"ec2:TerminateInstances",
"ec2:DeleteVolume",
"ec2:DescribeVolumeStatus",
"ec2:DescribeVolumes",
"ec2:DetachVolume",
"iam:GetRole",
"iam:GetRolePolicy",
"iam:ListInstanceProfiles",
"iam:ListRolePolicies",
"iam:PassRole",
"s3:CreateBucket",
"s3:Get*",
"s3:List*",
"sdb:BatchPutAttributes",
"sdb:Select",
"sqs:CreateQueue",
"sqs:Delete*",
"sqs:GetQueue*",
"sqs:PurgeQueue",
"sqs:ReceiveMessage"
]
}]
}
),
AWSManagedPolicy(
'AmazonElasticMapReduceforEC2Role',
default_version_id='v2',
path='/service-role/',
document={
"Version": "2012-10-17",
"Statement": [{
"Effect": "Allow",
"Resource": "*",
"Action": [
"cloudwatch:*",
"dynamodb:*",
"ec2:Describe*",
"elasticmapreduce:Describe*",
"elasticmapreduce:ListBootstrapActions",
"elasticmapreduce:ListClusters",
"elasticmapreduce:ListInstanceGroups",
"elasticmapreduce:ListInstances",
"elasticmapreduce:ListSteps",
"kinesis:CreateStream",
"kinesis:DeleteStream",
"kinesis:DescribeStream",
"kinesis:GetRecords",
"kinesis:GetShardIterator",
"kinesis:MergeShards",
"kinesis:PutRecord",
"kinesis:SplitShard",
"rds:Describe*",
"s3:*",
"sdb:*",
"sns:*",
"sqs:*"
]
}]
}
)
]
# TODO: add more predefined AWS managed policies
class IAMBackend(BaseBackend): class IAMBackend(BaseBackend):
def __init__(self): def __init__(self):
@ -486,6 +407,7 @@ class IAMBackend(BaseBackend):
self.users = {} self.users = {}
self.credential_report = None self.credential_report = None
self.managed_policies = self._init_managed_policies() self.managed_policies = self._init_managed_policies()
self.account_aliases = []
super(IAMBackend, self).__init__() super(IAMBackend, self).__init__()
def _init_managed_policies(self): def _init_managed_policies(self):
@ -494,7 +416,47 @@ class IAMBackend(BaseBackend):
def attach_role_policy(self, policy_arn, role_name): def attach_role_policy(self, policy_arn, role_name):
arns = dict((p.arn, p) for p in self.managed_policies.values()) arns = dict((p.arn, p) for p in self.managed_policies.values())
policy = arns[policy_arn] policy = arns[policy_arn]
policy.attach_to_role(self.get_role(role_name)) policy.attach_to(self.get_role(role_name))
def detach_role_policy(self, policy_arn, role_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
policy.detach_from(self.get_role(role_name))
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
def attach_group_policy(self, policy_arn, group_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.attach_to(self.get_group(group_name))
def detach_group_policy(self, policy_arn, group_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.detach_from(self.get_group(group_name))
def attach_user_policy(self, policy_arn, user_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.attach_to(self.get_user(user_name))
def detach_user_policy(self, policy_arn, user_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.detach_from(self.get_user(user_name))
def create_policy(self, description, path, policy_document, policy_name): def create_policy(self, description, path, policy_document, policy_name):
policy = ManagedPolicy( policy = ManagedPolicy(
@ -511,21 +473,15 @@ class IAMBackend(BaseBackend):
def list_attached_role_policies(self, role_name, marker=None, max_items=100, path_prefix='/'): def list_attached_role_policies(self, role_name, marker=None, max_items=100, path_prefix='/'):
policies = self.get_role(role_name).managed_policies.values() policies = self.get_role(role_name).managed_policies.values()
return self._filter_attached_policies(policies, marker, max_items, path_prefix)
if path_prefix: def list_attached_group_policies(self, group_name, marker=None, max_items=100, path_prefix='/'):
policies = [p for p in policies if p.path.startswith(path_prefix)] policies = self.get_group(group_name).managed_policies.values()
return self._filter_attached_policies(policies, marker, max_items, path_prefix)
policies = sorted(policies, key=lambda policy: policy.name) def list_attached_user_policies(self, user_name, marker=None, max_items=100, path_prefix='/'):
start_idx = int(marker) if marker else 0 policies = self.get_user(user_name).managed_policies.values()
return self._filter_attached_policies(policies, marker, max_items, path_prefix)
policies = policies[start_idx:start_idx + max_items]
if len(policies) < max_items:
marker = None
else:
marker = str(start_idx + max_items)
return policies, marker
def list_policies(self, marker, max_items, only_attached, path_prefix, scope): def list_policies(self, marker, max_items, only_attached, path_prefix, scope):
policies = self.managed_policies.values() policies = self.managed_policies.values()
@ -539,6 +495,9 @@ class IAMBackend(BaseBackend):
policies = [p for p in policies if not isinstance( policies = [p for p in policies if not isinstance(
p, AWSManagedPolicy)] p, AWSManagedPolicy)]
return self._filter_attached_policies(policies, marker, max_items, path_prefix)
def _filter_attached_policies(self, policies, marker, max_items, path_prefix):
if path_prefix: if path_prefix:
policies = [p for p in policies if p.path.startswith(path_prefix)] policies = [p for p in policies if p.path.startswith(path_prefix)]
@ -569,6 +528,12 @@ class IAMBackend(BaseBackend):
return role return role
raise IAMNotFoundException("Role {0} not found".format(role_name)) raise IAMNotFoundException("Role {0} not found".format(role_name))
def get_role_by_arn(self, arn):
for role in self.get_roles():
if role.arn == arn:
return role
raise IAMNotFoundException("Role {0} not found".format(arn))
def delete_role(self, role_name): def delete_role(self, role_name):
for role in self.get_roles(): for role in self.get_roles():
if role.name == role_name: if role.name == role_name:
@ -583,6 +548,10 @@ class IAMBackend(BaseBackend):
role = self.get_role(role_name) role = self.get_role(role_name)
role.put_policy(policy_name, policy_json) role.put_policy(policy_name, policy_json)
def delete_role_policy(self, role_name, policy_name):
role = self.get_role(role_name)
role.delete_policy(policy_name)
def get_role_policy(self, role_name, policy_name): def get_role_policy(self, role_name, policy_name):
role = self.get_role(role_name) role = self.get_role(role_name)
for p, d in role.policies.items(): for p, d in role.policies.items():
@ -772,6 +741,24 @@ class IAMBackend(BaseBackend):
raise IAMConflictException( raise IAMConflictException(
"User {0} already has password".format(user_name)) "User {0} already has password".format(user_name))
user.password = password user.password = password
return user
def get_login_profile(self, user_name):
user = self.get_user(user_name)
if not user.password:
raise IAMNotFoundException(
"Login profile for {0} not found".format(user_name))
return user
def update_login_profile(self, user_name, password, password_reset_required):
# This does not currently deal with PasswordPolicyViolation.
user = self.get_user(user_name)
if not user.password:
raise IAMNotFoundException(
"Login profile for {0} not found".format(user_name))
user.password = password
user.password_reset_required = password_reset_required
return user
def delete_login_profile(self, user_name): def delete_login_profile(self, user_name):
user = self.get_user(user_name) user = self.get_user(user_name)
@ -878,5 +865,15 @@ class IAMBackend(BaseBackend):
report += self.users[user].to_csv() report += self.users[user].to_csv()
return base64.b64encode(report.encode('ascii')).decode('ascii') return base64.b64encode(report.encode('ascii')).decode('ascii')
def list_account_aliases(self):
return self.account_aliases
def create_account_alias(self, alias):
# alias is force updated
self.account_aliases = [alias]
def delete_account_alias(self, alias):
self.account_aliases = []
iam_backend = IAMBackend() iam_backend = IAMBackend()

View File

@ -13,6 +13,41 @@ class IamResponse(BaseResponse):
template = self.response_template(ATTACH_ROLE_POLICY_TEMPLATE) template = self.response_template(ATTACH_ROLE_POLICY_TEMPLATE)
return template.render() return template.render()
def detach_role_policy(self):
role_name = self._get_param('RoleName')
policy_arn = self._get_param('PolicyArn')
iam_backend.detach_role_policy(policy_arn, role_name)
template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="DetachRolePolicyResponse")
def attach_group_policy(self):
policy_arn = self._get_param('PolicyArn')
group_name = self._get_param('GroupName')
iam_backend.attach_group_policy(policy_arn, group_name)
template = self.response_template(ATTACH_GROUP_POLICY_TEMPLATE)
return template.render()
def detach_group_policy(self):
policy_arn = self._get_param('PolicyArn')
group_name = self._get_param('GroupName')
iam_backend.detach_group_policy(policy_arn, group_name)
template = self.response_template(DETACH_GROUP_POLICY_TEMPLATE)
return template.render()
def attach_user_policy(self):
policy_arn = self._get_param('PolicyArn')
user_name = self._get_param('UserName')
iam_backend.attach_user_policy(policy_arn, user_name)
template = self.response_template(ATTACH_USER_POLICY_TEMPLATE)
return template.render()
def detach_user_policy(self):
policy_arn = self._get_param('PolicyArn')
user_name = self._get_param('UserName')
iam_backend.detach_user_policy(policy_arn, user_name)
template = self.response_template(DETACH_USER_POLICY_TEMPLATE)
return template.render()
def create_policy(self): def create_policy(self):
description = self._get_param('Description') description = self._get_param('Description')
path = self._get_param('Path') path = self._get_param('Path')
@ -33,6 +68,28 @@ class IamResponse(BaseResponse):
template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE) template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE)
return template.render(policies=policies, marker=marker) return template.render(policies=policies, marker=marker)
def list_attached_group_policies(self):
marker = self._get_param('Marker')
max_items = self._get_int_param('MaxItems', 100)
path_prefix = self._get_param('PathPrefix', '/')
group_name = self._get_param('GroupName')
policies, marker = iam_backend.list_attached_group_policies(
group_name, marker=marker, max_items=max_items,
path_prefix=path_prefix)
template = self.response_template(LIST_ATTACHED_GROUP_POLICIES_TEMPLATE)
return template.render(policies=policies, marker=marker)
def list_attached_user_policies(self):
marker = self._get_param('Marker')
max_items = self._get_int_param('MaxItems', 100)
path_prefix = self._get_param('PathPrefix', '/')
user_name = self._get_param('UserName')
policies, marker = iam_backend.list_attached_user_policies(
user_name, marker=marker, max_items=max_items,
path_prefix=path_prefix)
template = self.response_template(LIST_ATTACHED_USER_POLICIES_TEMPLATE)
return template.render(policies=policies, marker=marker)
def list_policies(self): def list_policies(self):
marker = self._get_param('Marker') marker = self._get_param('Marker')
max_items = self._get_int_param('MaxItems', 100) max_items = self._get_int_param('MaxItems', 100)
@ -82,6 +139,13 @@ class IamResponse(BaseResponse):
template = self.response_template(GENERIC_EMPTY_TEMPLATE) template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="PutRolePolicyResponse") return template.render(name="PutRolePolicyResponse")
def delete_role_policy(self):
role_name = self._get_param('RoleName')
policy_name = self._get_param('PolicyName')
iam_backend.delete_role_policy(role_name, policy_name)
template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="DeleteRolePolicyResponse")
def get_role_policy(self): def get_role_policy(self):
role_name = self._get_param('RoleName') role_name = self._get_param('RoleName')
policy_name = self._get_param('PolicyName') policy_name = self._get_param('PolicyName')
@ -290,10 +354,27 @@ class IamResponse(BaseResponse):
def create_login_profile(self): def create_login_profile(self):
user_name = self._get_param('UserName') user_name = self._get_param('UserName')
password = self._get_param('Password') password = self._get_param('Password')
iam_backend.create_login_profile(user_name, password) password = self._get_param('Password')
user = iam_backend.create_login_profile(user_name, password)
template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE) template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE)
return template.render(user_name=user_name) return template.render(user=user)
def get_login_profile(self):
user_name = self._get_param('UserName')
user = iam_backend.get_login_profile(user_name)
template = self.response_template(GET_LOGIN_PROFILE_TEMPLATE)
return template.render(user=user)
def update_login_profile(self):
user_name = self._get_param('UserName')
password = self._get_param('Password')
password_reset_required = self._get_param('PasswordResetRequired')
user = iam_backend.update_login_profile(user_name, password, password_reset_required)
template = self.response_template(UPDATE_LOGIN_PROFILE_TEMPLATE)
return template.render(user=user)
def add_user_to_group(self): def add_user_to_group(self):
group_name = self._get_param('GroupName') group_name = self._get_param('GroupName')
@ -422,6 +503,23 @@ class IamResponse(BaseResponse):
template = self.response_template(CREDENTIAL_REPORT) template = self.response_template(CREDENTIAL_REPORT)
return template.render(report=report) return template.render(report=report)
def list_account_aliases(self):
aliases = iam_backend.list_account_aliases()
template = self.response_template(LIST_ACCOUNT_ALIASES_TEMPLATE)
return template.render(aliases=aliases)
def create_account_alias(self):
alias = self._get_param('AccountAlias')
iam_backend.create_account_alias(alias)
template = self.response_template(CREATE_ACCOUNT_ALIAS_TEMPLATE)
return template.render()
def delete_account_alias(self):
alias = self._get_param('AccountAlias')
iam_backend.delete_account_alias(alias)
template = self.response_template(DELETE_ACCOUNT_ALIAS_TEMPLATE)
return template.render()
ATTACH_ROLE_POLICY_TEMPLATE = """<AttachRolePolicyResponse> ATTACH_ROLE_POLICY_TEMPLATE = """<AttachRolePolicyResponse>
<ResponseMetadata> <ResponseMetadata>
@ -429,6 +527,36 @@ ATTACH_ROLE_POLICY_TEMPLATE = """<AttachRolePolicyResponse>
</ResponseMetadata> </ResponseMetadata>
</AttachRolePolicyResponse>""" </AttachRolePolicyResponse>"""
DETACH_ROLE_POLICY_TEMPLATE = """<DetachRolePolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</DetachRolePolicyResponse>"""
ATTACH_USER_POLICY_TEMPLATE = """<AttachUserPolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</AttachUserPolicyResponse>"""
DETACH_USER_POLICY_TEMPLATE = """<DetachUserPolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</DetachUserPolicyResponse>"""
ATTACH_GROUP_POLICY_TEMPLATE = """<AttachGroupPolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</AttachGroupPolicyResponse>"""
DETACH_GROUP_POLICY_TEMPLATE = """<DetachGroupPolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</DetachGroupPolicyResponse>"""
CREATE_POLICY_TEMPLATE = """<CreatePolicyResponse> CREATE_POLICY_TEMPLATE = """<CreatePolicyResponse>
<CreatePolicyResult> <CreatePolicyResult>
<Policy> <Policy>
@ -469,6 +597,50 @@ LIST_ATTACHED_ROLE_POLICIES_TEMPLATE = """<ListAttachedRolePoliciesResponse>
</ResponseMetadata> </ResponseMetadata>
</ListAttachedRolePoliciesResponse>""" </ListAttachedRolePoliciesResponse>"""
LIST_ATTACHED_GROUP_POLICIES_TEMPLATE = """<ListAttachedGroupPoliciesResponse>
<ListAttachedGroupPoliciesResult>
{% if marker is none %}
<IsTruncated>false</IsTruncated>
{% else %}
<IsTruncated>true</IsTruncated>
<Marker>{{ marker }}</Marker>
{% endif %}
<AttachedPolicies>
{% for policy in policies %}
<member>
<PolicyName>{{ policy.name }}</PolicyName>
<PolicyArn>{{ policy.arn }}</PolicyArn>
</member>
{% endfor %}
</AttachedPolicies>
</ListAttachedGroupPoliciesResult>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</ListAttachedGroupPoliciesResponse>"""
LIST_ATTACHED_USER_POLICIES_TEMPLATE = """<ListAttachedUserPoliciesResponse>
<ListAttachedUserPoliciesResult>
{% if marker is none %}
<IsTruncated>false</IsTruncated>
{% else %}
<IsTruncated>true</IsTruncated>
<Marker>{{ marker }}</Marker>
{% endif %}
<AttachedPolicies>
{% for policy in policies %}
<member>
<PolicyName>{{ policy.name }}</PolicyName>
<PolicyArn>{{ policy.arn }}</PolicyArn>
</member>
{% endfor %}
</AttachedPolicies>
</ListAttachedUserPoliciesResult>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</ListAttachedUserPoliciesResponse>"""
LIST_POLICIES_TEMPLATE = """<ListPoliciesResponse> LIST_POLICIES_TEMPLATE = """<ListPoliciesResponse>
<ListPoliciesResult> <ListPoliciesResult>
{% if marker is none %} {% if marker is none %}
@ -918,12 +1090,11 @@ LIST_USERS_TEMPLATE = """<{{ action }}UsersResponse>
</ResponseMetadata> </ResponseMetadata>
</{{ action }}UsersResponse>""" </{{ action }}UsersResponse>"""
CREATE_LOGIN_PROFILE_TEMPLATE = """ CREATE_LOGIN_PROFILE_TEMPLATE = """<CreateLoginProfileResponse>
<CreateLoginProfileResponse>
<CreateLoginProfileResult> <CreateLoginProfileResult>
<LoginProfile> <LoginProfile>
<UserName>{{ user_name }}</UserName> <UserName>{{ user.name }}</UserName>
<CreateDate>2011-09-19T23:00:56Z</CreateDate> <CreateDate>{{ user.created_iso_8601 }}</CreateDate>
</LoginProfile> </LoginProfile>
</CreateLoginProfileResult> </CreateLoginProfileResult>
<ResponseMetadata> <ResponseMetadata>
@ -932,6 +1103,29 @@ CREATE_LOGIN_PROFILE_TEMPLATE = """
</CreateLoginProfileResponse> </CreateLoginProfileResponse>
""" """
GET_LOGIN_PROFILE_TEMPLATE = """<GetLoginProfileResponse>
<GetLoginProfileResult>
<LoginProfile>
<UserName>{{ user.name }}</UserName>
<CreateDate>{{ user.created_iso_8601 }}</CreateDate>
{% if user.password_reset_required %}
<PasswordResetRequired>true</PasswordResetRequired>
{% endif %}
</LoginProfile>
</GetLoginProfileResult>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</GetLoginProfileResponse>
"""
UPDATE_LOGIN_PROFILE_TEMPLATE = """<UpdateLoginProfileResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</UpdateLoginProfileResponse>
"""
GET_USER_POLICY_TEMPLATE = """<GetUserPolicyResponse> GET_USER_POLICY_TEMPLATE = """<GetUserPolicyResponse>
<GetUserPolicyResult> <GetUserPolicyResult>
<UserName>{{ user_name }}</UserName> <UserName>{{ user_name }}</UserName>
@ -965,9 +1159,7 @@ CREATE_ACCESS_KEY_TEMPLATE = """<CreateAccessKeyResponse>
<UserName>{{ key.user_name }}</UserName> <UserName>{{ key.user_name }}</UserName>
<AccessKeyId>{{ key.access_key_id }}</AccessKeyId> <AccessKeyId>{{ key.access_key_id }}</AccessKeyId>
<Status>{{ key.status }}</Status> <Status>{{ key.status }}</Status>
<SecretAccessKey> <SecretAccessKey>{{ key.secret_access_key }}</SecretAccessKey>
{{ key.secret_access_key }}
</SecretAccessKey>
</AccessKey> </AccessKey>
</CreateAccessKeyResult> </CreateAccessKeyResult>
<ResponseMetadata> <ResponseMetadata>
@ -1074,3 +1266,32 @@ LIST_MFA_DEVICES_TEMPLATE = """<ListMFADevicesResponse>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId> <RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata> </ResponseMetadata>
</ListMFADevicesResponse>""" </ListMFADevicesResponse>"""
LIST_ACCOUNT_ALIASES_TEMPLATE = """<ListAccountAliasesResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<ListAccountAliasesResult>
<IsTruncated>false</IsTruncated>
<AccountAliases>
{% for alias in aliases %}
<member>{{ alias }}</member>
{% endfor %}
</AccountAliases>
</ListAccountAliasesResult>
<ResponseMetadata>
<RequestId>c5a076e9-f1b0-11df-8fbe-45274EXAMPLE</RequestId>
</ResponseMetadata>
</ListAccountAliasesResponse>"""
CREATE_ACCOUNT_ALIAS_TEMPLATE = """<CreateAccountAliasResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<ResponseMetadata>
<RequestId>36b5db08-f1b0-11df-8fbe-45274EXAMPLE</RequestId>
</ResponseMetadata>
</CreateAccountAliasResponse>"""
DELETE_ACCOUNT_ALIAS_TEMPLATE = """<DeleteAccountAliasResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</DeleteAccountAliasResponse>"""

View File

@ -172,6 +172,13 @@ class Stream(BaseModel):
} }
} }
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
region = properties.get('Region', 'us-east-1')
shard_count = properties.get('ShardCount', 1)
return Stream(properties['Name'], shard_count, region)
class FirehoseRecord(BaseModel): class FirehoseRecord(BaseModel):

5
moto/logs/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from .models import logs_backends
from ..core.models import base_decorator, deprecated_base_decorator
mock_logs = base_decorator(logs_backends)
mock_logs_deprecated = deprecated_base_decorator(logs_backends)

242
moto/logs/models.py Normal file
View File

@ -0,0 +1,242 @@
from moto.core import BaseBackend
import boto.logs
from moto.core.utils import unix_time_millis
class LogEvent:
_event_id = 0
def __init__(self, ingestion_time, log_event):
self.ingestionTime = ingestion_time
self.timestamp = log_event["timestamp"]
self.message = log_event['message']
self.eventId = self.__class__._event_id
self.__class__._event_id += 1
def to_filter_dict(self):
return {
"eventId": self.eventId,
"ingestionTime": self.ingestionTime,
# "logStreamName":
"message": self.message,
"timestamp": self.timestamp
}
def to_response_dict(self):
return {
"ingestionTime": self.ingestionTime,
"message": self.message,
"timestamp": self.timestamp
}
class LogStream:
_log_ids = 0
def __init__(self, region, log_group, name):
self.region = region
self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format(
region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name)
self.creationTime = unix_time_millis()
self.firstEventTimestamp = None
self.lastEventTimestamp = None
self.lastIngestionTime = None
self.logStreamName = name
self.storedBytes = 0
self.uploadSequenceToken = 0 # I'm guessing this is token needed for sequenceToken by put_events
self.events = []
self.__class__._log_ids += 1
def _update(self):
self.firstEventTimestamp = min([x.timestamp for x in self.events])
self.lastEventTimestamp = max([x.timestamp for x in self.events])
def to_describe_dict(self):
# Compute start and end times
self._update()
return {
"arn": self.arn,
"creationTime": self.creationTime,
"firstEventTimestamp": self.firstEventTimestamp,
"lastEventTimestamp": self.lastEventTimestamp,
"lastIngestionTime": self.lastIngestionTime,
"logStreamName": self.logStreamName,
"storedBytes": self.storedBytes,
"uploadSequenceToken": str(self.uploadSequenceToken),
}
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
# TODO: ensure sequence_token
# TODO: to be thread safe this would need a lock
self.lastIngestionTime = unix_time_millis()
# TODO: make this match AWS if possible
self.storedBytes += sum([len(log_event["message"]) for log_event in log_events])
self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events]
self.uploadSequenceToken += 1
return self.uploadSequenceToken
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
def filter_func(event):
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head)
back_token = next_token
if next_token is None:
next_token = 0
events_page = [event.to_response_dict() for event in events[next_token: next_token + limit]]
next_token += limit
if next_token >= len(self.events):
next_token = None
return events_page, back_token, next_token
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
def filter_func(event):
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
events = []
for event in sorted(filter(filter_func, self.events), key=lambda x: x.timestamp):
event_obj = event.to_filter_dict()
event_obj['logStreamName'] = self.logStreamName
events.append(event_obj)
return events
class LogGroup:
def __init__(self, region, name, tags):
self.name = name
self.region = region
self.tags = tags
self.streams = dict() # {name: LogStream}
def create_log_stream(self, log_stream_name):
assert log_stream_name not in self.streams
self.streams[log_stream_name] = LogStream(self.region, self.name, log_stream_name)
def delete_log_stream(self, log_stream_name):
assert log_stream_name in self.streams
del self.streams[log_stream_name]
def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by):
log_streams = [(name, stream.to_describe_dict()) for name, stream in self.streams.items() if name.startswith(log_stream_name_prefix)]
def sorter(item):
return item[0] if order_by == 'logStreamName' else item[1]['lastEventTimestamp']
if next_token is None:
next_token = 0
log_streams = sorted(log_streams, key=sorter, reverse=descending)
new_token = next_token + limit
log_streams_page = [x[1] for x in log_streams[next_token: new_token]]
if new_token >= len(log_streams):
new_token = None
return log_streams_page, new_token
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
assert log_stream_name in self.streams
stream = self.streams[log_stream_name]
return stream.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
assert log_stream_name in self.streams
stream = self.streams[log_stream_name]
return stream.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
assert not filter_pattern # TODO: impl
streams = [stream for name, stream in self.streams.items() if not log_stream_names or name in log_stream_names]
events = []
for stream in streams:
events += stream.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
if interleaved:
events = sorted(events, key=lambda event: event.timestamp)
if next_token is None:
next_token = 0
events_page = events[next_token: next_token + limit]
next_token += limit
if next_token >= len(events):
next_token = None
searched_streams = [{"logStreamName": stream.logStreamName, "searchedCompletely": True} for stream in streams]
return events_page, next_token, searched_streams
class LogsBackend(BaseBackend):
def __init__(self, region_name):
self.region_name = region_name
self.groups = dict() # { logGroupName: LogGroup}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_log_group(self, log_group_name, tags):
assert log_group_name not in self.groups
self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags)
def ensure_log_group(self, log_group_name, tags):
if log_group_name in self.groups:
return
self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags)
def delete_log_group(self, log_group_name):
assert log_group_name in self.groups
del self.groups[log_group_name]
def create_log_stream(self, log_group_name, log_stream_name):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.create_log_stream(log_stream_name)
def delete_log_stream(self, log_group_name, log_stream_name):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.delete_log_stream(log_stream_name)
def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.describe_log_streams(descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by)
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
# TODO: add support for sequence_tokens
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
logs_backends = {region.name: LogsBackend(region.name) for region in boto.logs.regions()}

114
moto/logs/responses.py Normal file
View File

@ -0,0 +1,114 @@
from moto.core.responses import BaseResponse
from .models import logs_backends
import json
# See http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/Welcome.html
class LogsResponse(BaseResponse):
@property
def logs_backend(self):
return logs_backends[self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none)
def create_log_group(self):
log_group_name = self._get_param('logGroupName')
tags = self._get_param('tags')
assert 1 <= len(log_group_name) <= 512 # TODO: assert pattern
self.logs_backend.create_log_group(log_group_name, tags)
return ''
def delete_log_group(self):
log_group_name = self._get_param('logGroupName')
self.logs_backend.delete_log_group(log_group_name)
return ''
def create_log_stream(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
self.logs_backend.create_log_stream(log_group_name, log_stream_name)
return ''
def delete_log_stream(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
self.logs_backend.delete_log_stream(log_group_name, log_stream_name)
return ''
def describe_log_streams(self):
log_group_name = self._get_param('logGroupName')
log_stream_name_prefix = self._get_param('logStreamNamePrefix', '')
descending = self._get_param('descending', False)
limit = self._get_param('limit', 50)
assert limit <= 50
next_token = self._get_param('nextToken')
order_by = self._get_param('orderBy', 'LogStreamName')
assert order_by in {'LogStreamName', 'LastEventTime'}
if order_by == 'LastEventTime':
assert not log_stream_name_prefix
streams, next_token = self.logs_backend.describe_log_streams(
descending, limit, log_group_name, log_stream_name_prefix,
next_token, order_by)
return json.dumps({
"logStreams": streams,
"nextToken": next_token
})
def put_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
log_events = self._get_param('logEvents')
sequence_token = self._get_param('sequenceToken')
next_sequence_token = self.logs_backend.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
return json.dumps({'nextSequenceToken': next_sequence_token})
def get_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
start_time = self._get_param('startTime')
end_time = self._get_param("endTime")
limit = self._get_param('limit', 10000)
assert limit <= 10000
next_token = self._get_param('nextToken')
start_from_head = self._get_param('startFromHead', False)
events, next_backward_token, next_foward_token = \
self.logs_backend.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
return json.dumps({
"events": [ob.__dict__ for ob in events],
"nextBackwardToken": next_backward_token,
"nextForwardToken": next_foward_token
})
def filter_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_names = self._get_param('logStreamNames', [])
start_time = self._get_param('startTime')
# impl, see: http://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html
filter_pattern = self._get_param('filterPattern')
interleaved = self._get_param('interleaved', False)
end_time = self._get_param("endTime")
limit = self._get_param('limit', 10000)
assert limit <= 10000
next_token = self._get_param('nextToken')
events, next_token, searched_streams = self.logs_backend.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
return json.dumps({
"events": events,
"nextToken": next_token,
"searchedLogStreams": searched_streams
})

9
moto/logs/urls.py Normal file
View File

@ -0,0 +1,9 @@
from .responses import LogsResponse
url_bases = [
"https?://logs.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': LogsResponse.dispatch,
}

View File

@ -422,11 +422,11 @@ class OpsWorksBackend(BaseBackend):
stackid = kwargs['stack_id'] stackid = kwargs['stack_id']
if stackid not in self.stacks: if stackid not in self.stacks:
raise ResourceNotFoundException(stackid) raise ResourceNotFoundException(stackid)
if name in [l.name for l in self.layers.values()]: if name in [l.name for l in self.stacks[stackid].layers]:
raise ValidationException( raise ValidationException(
'There is already a layer named "{0}" ' 'There is already a layer named "{0}" '
'for this stack'.format(name)) 'for this stack'.format(name))
if shortname in [l.shortname for l in self.layers.values()]: if shortname in [l.shortname for l in self.stacks[stackid].layers]:
raise ValidationException( raise ValidationException(
'There is already a layer with shortname "{0}" ' 'There is already a layer with shortname "{0}" '
'for this stack'.format(shortname)) 'for this stack'.format(shortname))

View File

@ -72,6 +72,10 @@ from datetime import datetime
from datetime import timedelta from datetime import timedelta
from errno import EAGAIN from errno import EAGAIN
# Some versions of python internally shadowed the
# SocketType variable incorrectly https://bugs.python.org/issue20386
BAD_SOCKET_SHADOW = socket.socket != socket.SocketType
old_socket = socket.socket old_socket = socket.socket
old_create_connection = socket.create_connection old_create_connection = socket.create_connection
old_gethostbyname = socket.gethostbyname old_gethostbyname = socket.gethostbyname
@ -99,6 +103,12 @@ try: # pragma: no cover
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
ssl = None ssl = None
try: # pragma: no cover
from requests.packages.urllib3.contrib.pyopenssl import inject_into_urllib3, extract_from_urllib3
pyopenssl_override = True
except:
pyopenssl_override = False
DEFAULT_HTTP_PORTS = frozenset([80]) DEFAULT_HTTP_PORTS = frozenset([80])
POTENTIAL_HTTP_PORTS = set(DEFAULT_HTTP_PORTS) POTENTIAL_HTTP_PORTS = set(DEFAULT_HTTP_PORTS)
@ -976,7 +986,8 @@ class httpretty(HttpBaseClass):
def disable(cls): def disable(cls):
cls._is_enabled = False cls._is_enabled = False
socket.socket = old_socket socket.socket = old_socket
socket.SocketType = old_socket if not BAD_SOCKET_SHADOW:
socket.SocketType = old_socket
socket._socketobject = old_socket socket._socketobject = old_socket
socket.create_connection = old_create_connection socket.create_connection = old_create_connection
@ -986,7 +997,8 @@ class httpretty(HttpBaseClass):
socket.__dict__['socket'] = old_socket socket.__dict__['socket'] = old_socket
socket.__dict__['_socketobject'] = old_socket socket.__dict__['_socketobject'] = old_socket
socket.__dict__['SocketType'] = old_socket if not BAD_SOCKET_SHADOW:
socket.__dict__['SocketType'] = old_socket
socket.__dict__['create_connection'] = old_create_connection socket.__dict__['create_connection'] = old_create_connection
socket.__dict__['gethostname'] = old_gethostname socket.__dict__['gethostname'] = old_gethostname
@ -1007,6 +1019,9 @@ class httpretty(HttpBaseClass):
ssl.sslwrap_simple = old_sslwrap_simple ssl.sslwrap_simple = old_sslwrap_simple
ssl.__dict__['sslwrap_simple'] = old_sslwrap_simple ssl.__dict__['sslwrap_simple'] = old_sslwrap_simple
if pyopenssl_override:
inject_into_urllib3()
@classmethod @classmethod
def is_enabled(cls): def is_enabled(cls):
return cls._is_enabled return cls._is_enabled
@ -1014,13 +1029,10 @@ class httpretty(HttpBaseClass):
@classmethod @classmethod
def enable(cls): def enable(cls):
cls._is_enabled = True cls._is_enabled = True
# Some versions of python internally shadowed the
# SocketType variable incorrectly https://bugs.python.org/issue20386
bad_socket_shadow = (socket.socket != socket.SocketType)
socket.socket = fakesock.socket socket.socket = fakesock.socket
socket._socketobject = fakesock.socket socket._socketobject = fakesock.socket
if not bad_socket_shadow: if not BAD_SOCKET_SHADOW:
socket.SocketType = fakesock.socket socket.SocketType = fakesock.socket
socket.create_connection = create_fake_connection socket.create_connection = create_fake_connection
@ -1030,7 +1042,7 @@ class httpretty(HttpBaseClass):
socket.__dict__['socket'] = fakesock.socket socket.__dict__['socket'] = fakesock.socket
socket.__dict__['_socketobject'] = fakesock.socket socket.__dict__['_socketobject'] = fakesock.socket
if not bad_socket_shadow: if not BAD_SOCKET_SHADOW:
socket.__dict__['SocketType'] = fakesock.socket socket.__dict__['SocketType'] = fakesock.socket
socket.__dict__['create_connection'] = create_fake_connection socket.__dict__['create_connection'] = create_fake_connection
@ -1053,6 +1065,9 @@ class httpretty(HttpBaseClass):
ssl.sslwrap_simple = fake_wrap_socket ssl.sslwrap_simple = fake_wrap_socket
ssl.__dict__['sslwrap_simple'] = fake_wrap_socket ssl.__dict__['sslwrap_simple'] = fake_wrap_socket
if pyopenssl_override:
extract_from_urllib3()
def httprettified(test): def httprettified(test):
"A decorator tests that use HTTPretty" "A decorator tests that use HTTPretty"

View File

@ -10,6 +10,7 @@ import six
from collections import namedtuple, Sequence, Sized from collections import namedtuple, Sequence, Sized
from functools import update_wrapper from functools import update_wrapper
from cookies import Cookies from cookies import Cookies
from requests.adapters import HTTPAdapter
from requests.utils import cookiejar_from_dict from requests.utils import cookiejar_from_dict
from requests.exceptions import ConnectionError from requests.exceptions import ConnectionError
from requests.sessions import REDIRECT_STATI from requests.sessions import REDIRECT_STATI
@ -120,10 +121,12 @@ class RequestsMock(object):
POST = 'POST' POST = 'POST'
PUT = 'PUT' PUT = 'PUT'
def __init__(self, assert_all_requests_are_fired=True): def __init__(self, assert_all_requests_are_fired=True, pass_through=True):
self._calls = CallList() self._calls = CallList()
self.reset() self.reset()
self.assert_all_requests_are_fired = assert_all_requests_are_fired self.assert_all_requests_are_fired = assert_all_requests_are_fired
self.pass_through = pass_through
self.original_send = HTTPAdapter.send
def reset(self): def reset(self):
self._urls = [] self._urls = []
@ -235,6 +238,9 @@ class RequestsMock(object):
match = self._find_match(request) match = self._find_match(request)
# TODO(dcramer): find the correct class for this # TODO(dcramer): find the correct class for this
if match is None: if match is None:
if self.pass_through:
return self.original_send(adapter, request, **kwargs)
error_msg = 'Connection refused: {0} {1}'.format(request.method, error_msg = 'Connection refused: {0} {1}'.format(request.method,
request.url) request.url)
response = ConnectionError(error_msg) response = ConnectionError(error_msg)
@ -270,6 +276,8 @@ class RequestsMock(object):
body=body, body=body,
headers=headers, headers=headers,
preload_content=False, preload_content=False,
# Need to not decode_content to mimic requests
decode_content=False,
) )
response = adapter.build_response(request, response) response = adapter.build_response(request, response)
@ -315,7 +323,7 @@ class RequestsMock(object):
# expose default mock namespace # expose default mock namespace
mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False) mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False, pass_through=False)
__all__ = [] __all__ = []
for __attr in (a for a in dir(_default_mock) if not a.startswith('_')): for __attr in (a for a in dir(_default_mock) if not a.startswith('_')):
__all__.append(__attr) __all__.append(__attr)

6
moto/polly/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import polly_backends
from ..core.models import base_decorator
polly_backend = polly_backends['us-east-1']
mock_polly = base_decorator(polly_backends)

114
moto/polly/models.py Normal file
View File

@ -0,0 +1,114 @@
from __future__ import unicode_literals
from xml.etree import ElementTree as ET
import datetime
import boto3
from moto.core import BaseBackend, BaseModel
from .resources import VOICE_DATA
from .utils import make_arn_for_lexicon
DEFAULT_ACCOUNT_ID = 123456789012
class Lexicon(BaseModel):
def __init__(self, name, content, region_name):
self.name = name
self.content = content
self.size = 0
self.alphabet = None
self.last_modified = None
self.language_code = None
self.lexemes_count = 0
self.arn = make_arn_for_lexicon(DEFAULT_ACCOUNT_ID, name, region_name)
self.update()
def update(self, content=None):
if content is not None:
self.content = content
# Probably a very naive approach, but it'll do for now.
try:
root = ET.fromstring(self.content)
self.size = len(self.content)
self.last_modified = int((datetime.datetime.now() -
datetime.datetime(1970, 1, 1)).total_seconds())
self.lexemes_count = len(root.findall('.'))
for key, value in root.attrib.items():
if key.endswith('alphabet'):
self.alphabet = value
elif key.endswith('lang'):
self.language_code = value
except Exception as err:
raise ValueError('Failure parsing XML: {0}'.format(err))
def to_dict(self):
return {
'Attributes': {
'Alphabet': self.alphabet,
'LanguageCode': self.language_code,
'LastModified': self.last_modified,
'LexemesCount': self.lexemes_count,
'LexiconArn': self.arn,
'Size': self.size
}
}
def __repr__(self):
return '<Lexicon {0}>'.format(self.name)
class PollyBackend(BaseBackend):
def __init__(self, region_name=None):
super(PollyBackend, self).__init__()
self.region_name = region_name
self._lexicons = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def describe_voices(self, language_code, next_token):
if language_code is None:
return VOICE_DATA
return [item for item in VOICE_DATA if item['LanguageCode'] == language_code]
def delete_lexicon(self, name):
# implement here
del self._lexicons[name]
def get_lexicon(self, name):
# Raises KeyError
return self._lexicons[name]
def list_lexicons(self, next_token):
result = []
for name, lexicon in self._lexicons.items():
lexicon_dict = lexicon.to_dict()
lexicon_dict['Name'] = name
result.append(lexicon_dict)
return result
def put_lexicon(self, name, content):
# If lexicon content is bad, it will raise ValueError
if name in self._lexicons:
# Regenerated all the stats from the XML
# but keeps the ARN
self._lexicons.update(content)
else:
lexicon = Lexicon(name, content, region_name=self.region_name)
self._lexicons[name] = lexicon
available_regions = boto3.session.Session().get_available_regions("polly")
polly_backends = {region: PollyBackend(region_name=region) for region in available_regions}

63
moto/polly/resources.py Normal file
View File

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
VOICE_DATA = [
{'Id': 'Joanna', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Joanna'},
{'Id': 'Mizuki', 'LanguageCode': 'ja-JP', 'LanguageName': 'Japanese', 'Gender': 'Female', 'Name': 'Mizuki'},
{'Id': 'Filiz', 'LanguageCode': 'tr-TR', 'LanguageName': 'Turkish', 'Gender': 'Female', 'Name': 'Filiz'},
{'Id': 'Astrid', 'LanguageCode': 'sv-SE', 'LanguageName': 'Swedish', 'Gender': 'Female', 'Name': 'Astrid'},
{'Id': 'Tatyana', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Female', 'Name': 'Tatyana'},
{'Id': 'Maxim', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Male', 'Name': 'Maxim'},
{'Id': 'Carmen', 'LanguageCode': 'ro-RO', 'LanguageName': 'Romanian', 'Gender': 'Female', 'Name': 'Carmen'},
{'Id': 'Ines', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Female', 'Name': 'Inês'},
{'Id': 'Cristiano', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Male', 'Name': 'Cristiano'},
{'Id': 'Vitoria', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Female', 'Name': 'Vitória'},
{'Id': 'Ricardo', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Male', 'Name': 'Ricardo'},
{'Id': 'Maja', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Maja'},
{'Id': 'Jan', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jan'},
{'Id': 'Ewa', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Ewa'},
{'Id': 'Ruben', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Male', 'Name': 'Ruben'},
{'Id': 'Lotte', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Female', 'Name': 'Lotte'},
{'Id': 'Liv', 'LanguageCode': 'nb-NO', 'LanguageName': 'Norwegian', 'Gender': 'Female', 'Name': 'Liv'},
{'Id': 'Giorgio', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Male', 'Name': 'Giorgio'},
{'Id': 'Carla', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Female', 'Name': 'Carla'},
{'Id': 'Karl', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Male', 'Name': 'Karl'},
{'Id': 'Dora', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Female', 'Name': 'Dóra'},
{'Id': 'Mathieu', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Male', 'Name': 'Mathieu'},
{'Id': 'Celine', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Female', 'Name': 'Céline'},
{'Id': 'Chantal', 'LanguageCode': 'fr-CA', 'LanguageName': 'Canadian French', 'Gender': 'Female', 'Name': 'Chantal'},
{'Id': 'Penelope', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Female', 'Name': 'Penélope'},
{'Id': 'Miguel', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Male', 'Name': 'Miguel'},
{'Id': 'Enrique', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Male', 'Name': 'Enrique'},
{'Id': 'Conchita', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Female', 'Name': 'Conchita'},
{'Id': 'Geraint', 'LanguageCode': 'en-GB-WLS', 'LanguageName': 'Welsh English', 'Gender': 'Male', 'Name': 'Geraint'},
{'Id': 'Salli', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Salli'},
{'Id': 'Kimberly', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kimberly'},
{'Id': 'Kendra', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kendra'},
{'Id': 'Justin', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Justin'},
{'Id': 'Joey', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Joey'},
{'Id': 'Ivy', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Ivy'},
{'Id': 'Raveena', 'LanguageCode': 'en-IN', 'LanguageName': 'Indian English', 'Gender': 'Female', 'Name': 'Raveena'},
{'Id': 'Emma', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Emma'},
{'Id': 'Brian', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Male', 'Name': 'Brian'},
{'Id': 'Amy', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Amy'},
{'Id': 'Russell', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Male', 'Name': 'Russell'},
{'Id': 'Nicole', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Female', 'Name': 'Nicole'},
{'Id': 'Vicki', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Vicki'},
{'Id': 'Marlene', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Marlene'},
{'Id': 'Hans', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Male', 'Name': 'Hans'},
{'Id': 'Naja', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Female', 'Name': 'Naja'},
{'Id': 'Mads', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Male', 'Name': 'Mads'},
{'Id': 'Gwyneth', 'LanguageCode': 'cy-GB', 'LanguageName': 'Welsh', 'Gender': 'Female', 'Name': 'Gwyneth'},
{'Id': 'Jacek', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jacek'}
]
# {...} is also shorthand set syntax
LANGUAGE_CODES = {'cy-GB', 'da-DK', 'de-DE', 'en-AU', 'en-GB', 'en-GB-WLS', 'en-IN', 'en-US', 'es-ES', 'es-US',
'fr-CA', 'fr-FR', 'is-IS', 'it-IT', 'ja-JP', 'nb-NO', 'nl-NL', 'pl-PL', 'pt-BR', 'pt-PT', 'ro-RO',
'ru-RU', 'sv-SE', 'tr-TR'}
VOICE_IDS = {'Geraint', 'Gwyneth', 'Mads', 'Naja', 'Hans', 'Marlene', 'Nicole', 'Russell', 'Amy', 'Brian', 'Emma',
'Raveena', 'Ivy', 'Joanna', 'Joey', 'Justin', 'Kendra', 'Kimberly', 'Salli', 'Conchita', 'Enrique',
'Miguel', 'Penelope', 'Chantal', 'Celine', 'Mathieu', 'Dora', 'Karl', 'Carla', 'Giorgio', 'Mizuki',
'Liv', 'Lotte', 'Ruben', 'Ewa', 'Jacek', 'Jan', 'Maja', 'Ricardo', 'Vitoria', 'Cristiano', 'Ines',
'Carmen', 'Maxim', 'Tatyana', 'Astrid', 'Filiz'}

Some files were not shown because too many files have changed in this diff Show More