Merge remote-tracking branch 'spulec/master'

This commit is contained in:
Alexander Mohr 2017-10-02 12:38:11 -07:00
commit b9bf33209c
161 changed files with 23703 additions and 1379 deletions

3
.gitignore vendored
View File

@ -11,4 +11,5 @@ build/
.idea/
*.swp
.DS_Store
python_env
python_env
.ropeproject/

View File

@ -1,23 +1,36 @@
language: python
sudo: false
services:
- docker
python:
- 2.7
- 3.6
env:
- TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true
before_install:
- export BOTO_CONFIG=/dev/null
install:
- travis_retry pip install boto==2.45.0
- travis_retry pip install boto3
- travis_retry pip install .
- travis_retry pip install -r requirements-dev.txt
- travis_retry pip install coveralls
# We build moto first so the docker container doesn't try to compile it as well, also note we don't use
# -d for docker run so the logs show up in travis
# Python images come from here: https://hub.docker.com/_/python/
- |
python setup.py sdist
if [ "$TEST_SERVER_MODE" = "true" ]; then
AWS_SECRET_ACCESS_KEY=server_secret AWS_ACCESS_KEY_ID=server_key moto_server -p 5000&
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh &
export AWS_SECRET_ACCESS_KEY=foobar_secret
export AWS_ACCESS_KEY_ID=foobar_key
fi
travis_retry pip install boto==2.45.0
travis_retry pip install boto3
travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt
if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py
fi
script:
- make test
after_success:

View File

@ -3,6 +3,121 @@ Moto Changelog
Latest
------
1.1.20
-----
* Improved `make scaffold`
* Implemented IAM attached group policies
* Implemented skeleton of Cloudwatch Logs
* Redshift: fixed multi-params
* Redshift: implement taggable resources
* Lambda + SNS: Major enhancements
1.1.19
-----
* Fixing regression from 1.1.15
1.1.15
-----
* Polly implementation
* Added EC2 instance info
* SNS publish by phone number
1.1.14
-----
* ACM implementation
* Added `make scaffold`
* X-Ray implementation
1.1.13
-----
* Created alpine-based Dockerfile (dockerhub: motoserver/moto)
* SNS.SetSMSAttributes & SNS.GetSMSAttributes + Filtering
* S3 ACL implementation
* pushing to Dockerhub on `make publish`
1.1.12
-----
* implemented all AWS managed policies in source
* fixing Dynamodb CapacityUnits format
* S3 ACL implementation
1.1.11
-----
* S3 authentication
* SSM get_parameter
* ELBv2 target group tagging
* EC2 Security group filters
1.1.10
-----
* EC2 vpc address filtering
* EC2 elastic ip dissociation
* ELBv2 target group tagging
* fixed complexity of accepting new filter implementations
1.1.9
-----
* EC2 root device mapping
1.1.8
-----
* Lambda get_function for function created with zipfile
* scripts/implementation_coverage.py
1.1.7
-----
* Lambda invoke_async
* EC2 keypair filtering
1.1.6
-----
* Dynamo ADD and DELETE operations in update expressions
* Lambda tag support
1.1.5
-----
* Dynamo allow ADD update_item of a string set
* Handle max-keys in list-objects
* bugfixes in pagination
1.1.3
-----
* EC2 vpc_id in responses
1.1.2
-----
* IAM account aliases
* SNS subscription attributes
* bugfixes in Dynamo, CFN, and EC2
1.1.1
-----
* EC2 group-id filter
* EC2 list support for filters
1.1.0
-----
* Add ELBv2
* IAM user policies
* RDS snapshots
* IAM policy versions
1.0.1
-----

View File

@ -1,11 +1,22 @@
FROM python:2
FROM alpine:3.6
RUN apk add --no-cache --update \
gcc \
musl-dev \
python3-dev \
libffi-dev \
openssl-dev \
python3
ADD . /moto/
ENV PYTHONUNBUFFERED 1
WORKDIR /moto/
RUN pip install ".[server]"
RUN python3 -m ensurepip && \
rm -r /usr/lib/python*/ensurepip && \
pip3 --no-cache-dir install --upgrade pip setuptools && \
pip3 --no-cache-dir install ".[server]"
CMD ["moto_server"]
ENTRYPOINT ["/usr/bin/moto_server", "-H", "0.0.0.0"]
EXPOSE 5000

View File

@ -1,3 +1,4 @@
include README.md LICENSE AUTHORS.md
include requirements.txt requirements-dev.txt tox.ini
include moto/ec2/resources/instance_types.json
recursive-include tests *

View File

@ -15,5 +15,22 @@ test: lint
test_server:
@TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/
publish:
aws_managed_policies:
scripts/update_managed_policies.py
upload_pypi_artifact:
python setup.py sdist bdist_wheel upload
push_dockerhub_image:
docker build -t motoserver/moto .
docker push motoserver/moto
tag_github_release:
git tag `python setup.py --version`
git push origin `python setup.py --version`
publish: upload_pypi_artifact push_dockerhub_image tag_github_release
scaffold:
@pip install -r requirements-dev.txt > /dev/null
exec python scripts/scaffold.py

View File

@ -58,6 +58,8 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
|------------------------------------------------------------------------------|
| Service Name | Decorator | Development Status |
|------------------------------------------------------------------------------|
| ACM | @mock_acm | all endpoints done |
|------------------------------------------------------------------------------|
| API Gateway | @mock_apigateway | core endpoints done |
|------------------------------------------------------------------------------|
| Autoscaling | @mock_autoscaling| core endpoints done |
@ -78,10 +80,14 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
| - Security Groups | | core endpoints done |
| - Tags | | all endpoints done |
|------------------------------------------------------------------------------|
| ECR | @mock_ecr | basic endpoints done |
|------------------------------------------------------------------------------|
| ECS | @mock_ecs | basic endpoints done |
|------------------------------------------------------------------------------|
| ELB | @mock_elb | core endpoints done |
|------------------------------------------------------------------------------|
| ELBv2 | @mock_elbv2 | core endpoints done |
|------------------------------------------------------------------------------|
| EMR | @mock_emr | core endpoints done |
|------------------------------------------------------------------------------|
| Glacier | @mock_glacier | core endpoints done |
@ -90,10 +96,14 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
|------------------------------------------------------------------------------|
| Lambda | @mock_lambda | basic endpoints done |
|------------------------------------------------------------------------------|
| Logs | @mock_logs | basic endpoints done |
|------------------------------------------------------------------------------|
| Kinesis | @mock_kinesis | core endpoints done |
|------------------------------------------------------------------------------|
| KMS | @mock_kms | basic endpoints done |
|------------------------------------------------------------------------------|
| Polly | @mock_polly | all endpoints done |
|------------------------------------------------------------------------------|
| RDS | @mock_rds | core endpoints done |
|------------------------------------------------------------------------------|
| RDS2 | @mock_rds2 | core endpoints done |
@ -106,7 +116,7 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
|------------------------------------------------------------------------------|
| SES | @mock_ses | core endpoints done |
|------------------------------------------------------------------------------|
| SNS | @mock_sns | core endpoints done |
| SNS | @mock_sns | all endpoints done |
|------------------------------------------------------------------------------|
| SQS | @mock_sqs | core endpoints done |
|------------------------------------------------------------------------------|
@ -116,6 +126,8 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
|------------------------------------------------------------------------------|
| SWF | @mock_swf | basic endpoints done |
|------------------------------------------------------------------------------|
| X-Ray | @mock_xray | core endpoints done |
|------------------------------------------------------------------------------|
```
### Another Example
@ -153,9 +165,9 @@ moto 1.0.X mock docorators are defined for boto3 and do not work with boto2. Use
Using moto with boto2
```python
from moto import mock_ec2_deprecated
from moto import mock_ec2_deprecated
import boto
@mock_ec2_deprecated
def test_something_with_ec2():
ec2_conn = boto.ec2.connect_to_region('us-east-1')

View File

@ -5,6 +5,7 @@ import logging
__title__ = 'moto'
__version__ = '1.0.1'
from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # flake8: noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # flake8: noqa
@ -21,10 +22,11 @@ from .elbv2 import mock_elbv2 # flake8: noqa
from .emr import mock_emr, mock_emr_deprecated # flake8: noqa
from .events import mock_events # flake8: noqa
from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .iam import mock_iam, mock_iam_deprecated # flake8: noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa
from .kms import mock_kms, mock_kms_deprecated # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .polly import mock_polly # flake8: noqa
from .rds import mock_rds, mock_rds_deprecated # flake8: noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa
from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa
@ -36,6 +38,8 @@ from .sts import mock_sts, mock_sts_deprecated # flake8: noqa
from .ssm import mock_ssm # flake8: noqa
from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa
from .swf import mock_swf, mock_swf_deprecated # flake8: noqa
from .xray import mock_xray # flake8: noqa
from .logs import mock_logs, mock_logs_deprecated # flake8: noqa
try:

6
moto/acm/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import acm_backends
from ..core.models import base_decorator
acm_backend = acm_backends['us-east-1']
mock_acm = base_decorator(acm_backends)

395
moto/acm/models.py Normal file
View File

@ -0,0 +1,395 @@
from __future__ import unicode_literals
import re
import json
import datetime
from moto.core import BaseBackend, BaseModel
from moto.ec2 import ec2_backends
from .utils import make_arn_for_certificate
import cryptography.x509
import cryptography.hazmat.primitives.asymmetric.rsa
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.backends import default_backend
DEFAULT_ACCOUNT_ID = 123456789012
GOOGLE_ROOT_CA = b"""-----BEGIN CERTIFICATE-----
MIIEKDCCAxCgAwIBAgIQAQAhJYiw+lmnd+8Fe2Yn3zANBgkqhkiG9w0BAQsFADBC
MQswCQYDVQQGEwJVUzEWMBQGA1UEChMNR2VvVHJ1c3QgSW5jLjEbMBkGA1UEAxMS
R2VvVHJ1c3QgR2xvYmFsIENBMB4XDTE3MDUyMjExMzIzN1oXDTE4MTIzMTIzNTk1
OVowSTELMAkGA1UEBhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMT
HEdvb2dsZSBJbnRlcm5ldCBBdXRob3JpdHkgRzIwggEiMA0GCSqGSIb3DQEBAQUA
A4IBDwAwggEKAoIBAQCcKgR3XNhQkToGo4Lg2FBIvIk/8RlwGohGfuCPxfGJziHu
Wv5hDbcyRImgdAtTT1WkzoJile7rWV/G4QWAEsRelD+8W0g49FP3JOb7kekVxM/0
Uw30SvyfVN59vqBrb4fA0FAfKDADQNoIc1Fsf/86PKc3Bo69SxEE630k3ub5/DFx
+5TVYPMuSq9C0svqxGoassxT3RVLix/IGWEfzZ2oPmMrhDVpZYTIGcVGIvhTlb7j
gEoQxirsupcgEcc5mRAEoPBhepUljE5SdeK27QjKFPzOImqzTs9GA5eXA37Asd57
r0Uzz7o+cbfe9CUlwg01iZ2d+w4ReYkeN8WvjnJpAgMBAAGjggERMIIBDTAfBgNV
HSMEGDAWgBTAephojYn7qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1
dvWBtrtiGrpagS8wDgYDVR0PAQH/BAQDAgEGMC4GCCsGAQUFBwEBBCIwIDAeBggr
BgEFBQcwAYYSaHR0cDovL2cuc3ltY2QuY29tMBIGA1UdEwEB/wQIMAYBAf8CAQAw
NQYDVR0fBC4wLDAqoCigJoYkaHR0cDovL2cuc3ltY2IuY29tL2NybHMvZ3RnbG9i
YWwuY3JsMCEGA1UdIAQaMBgwDAYKKwYBBAHWeQIFATAIBgZngQwBAgIwHQYDVR0l
BBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMA0GCSqGSIb3DQEBCwUAA4IBAQDKSeWs
12Rkd1u+cfrP9B4jx5ppY1Rf60zWGSgjZGaOHMeHgGRfBIsmr5jfCnC8vBk97nsz
qX+99AXUcLsFJnnqmseYuQcZZTTMPOk/xQH6bwx+23pwXEz+LQDwyr4tjrSogPsB
E4jLnD/lu3fKOmc2887VJwJyQ6C9bgLxRwVxPgFZ6RGeGvOED4Cmong1L7bHon8X
fOGLVq7uZ4hRJzBgpWJSwzfVO+qFKgE4h6LPcK2kesnE58rF2rwjMvL+GMJ74N87
L9TQEOaWTPtEtyFkDbkAlDASJodYmDkFOA/MgkgMCkdm7r+0X8T/cKjhf4t5K7hl
MqO5tzHpCvX2HzLc
-----END CERTIFICATE-----"""
# Added google root CA as AWS returns chain you gave it + root CA (provided or not)
# so for now a cheap response is just give any old root CA
def datetime_to_epoch(date):
# As only Py3 has datetime.timestamp()
return int((date - datetime.datetime(1970, 1, 1)).total_seconds())
class AWSError(Exception):
TYPE = None
STATUS = 400
def __init__(self, message):
self.message = message
def response(self):
resp = {'__type': self.TYPE, 'message': self.message}
return json.dumps(resp), dict(status=self.STATUS)
class AWSValidationException(AWSError):
TYPE = 'ValidationException'
class AWSResourceNotFoundException(AWSError):
TYPE = 'ResourceNotFoundException'
class CertBundle(BaseModel):
def __init__(self, certificate, private_key, chain=None, region='us-east-1', arn=None, cert_type='IMPORTED', cert_status='ISSUED'):
self.created_at = datetime.datetime.now()
self.cert = certificate
self._cert = None
self.common_name = None
self.key = private_key
self._key = None
self.chain = chain
self.tags = {}
self._chain = None
self.type = cert_type # Should really be an enum
self.status = cert_status # Should really be an enum
# AWS always returns your chain + root CA
if self.chain is None:
self.chain = GOOGLE_ROOT_CA
else:
self.chain += b'\n' + GOOGLE_ROOT_CA
# Takes care of PEM checking
self.validate_pk()
self.validate_certificate()
if chain is not None:
self.validate_chain()
# TODO check cert is valid, or if self-signed then a chain is provided, otherwise
# raise AWSValidationException('Provided certificate is not a valid self signed. Please provide either a valid self-signed certificate or certificate chain.')
# Used for when one wants to overwrite an arn
if arn is None:
self.arn = make_arn_for_certificate(DEFAULT_ACCOUNT_ID, region)
else:
self.arn = arn
@classmethod
def generate_cert(cls, domain_name, sans=None):
if sans is None:
sans = set()
else:
sans = set(sans)
sans.add(domain_name)
sans = [cryptography.x509.DNSName(item) for item in sans]
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())
subject = cryptography.x509.Name([
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, u"CA"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.LOCALITY_NAME, u"San Francisco"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"My Company"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, domain_name),
])
issuer = cryptography.x509.Name([ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"Amazon"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, u"Server CA 1B"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, u"Amazon"),
])
cert = cryptography.x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
key.public_key()
).serial_number(
cryptography.x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=365)
).add_extension(
cryptography.x509.SubjectAlternativeName(sans),
critical=False,
).sign(key, hashes.SHA512(), default_backend())
cert_armored = cert.public_bytes(serialization.Encoding.PEM)
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION')
def validate_pk(self):
try:
self._key = serialization.load_pem_private_key(self.key, password=None, backend=default_backend())
if self._key.key_size > 2048:
AWSValidationException('The private key length is not supported. Only 1024-bit and 2048-bit are allowed.')
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The private key is not PEM-encoded or is not valid.')
def validate_certificate(self):
try:
self._cert = cryptography.x509.load_pem_x509_certificate(self.cert, default_backend())
now = datetime.datetime.now()
if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate has expired, is not valid.')
if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate is not in effect yet, is not valid.')
# Extracting some common fields for ease of use
# Have to search through cert.subject for OIDs
self.common_name = self._cert.subject.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.')
def validate_chain(self):
try:
self._chain = []
for cert_armored in self.chain.split(b'-\n-'):
# Would leave encoded but Py2 does not have raw binary strings
cert_armored = cert_armored.decode()
# Fix missing -'s on split
cert_armored = re.sub(r'^----B', '-----B', cert_armored)
cert_armored = re.sub(r'E----$', 'E-----', cert_armored)
cert = cryptography.x509.load_pem_x509_certificate(cert_armored.encode(), default_backend())
self._chain.append(cert)
now = datetime.datetime.now()
if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate chain has expired, is not valid.')
if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate chain is not in effect yet, is not valid.')
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.')
def check(self):
# Basically, if the certificate is pending, and then checked again after 1 min
# It will appear as if its been validated
if self.type == 'AMAZON_ISSUED' and self.status == 'PENDING_VALIDATION' and \
(datetime.datetime.now() - self.created_at).total_seconds() > 60: # 1min
self.status = 'ISSUED'
def describe(self):
# 'RenewalSummary': {}, # Only when cert is amazon issued
if self._key.key_size == 1024:
key_algo = 'RSA_1024'
elif self._key.key_size == 2048:
key_algo = 'RSA_2048'
else:
key_algo = 'EC_prime256v1'
# Look for SANs
san_obj = self._cert.extensions.get_extension_for_oid(cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME)
sans = []
if san_obj is not None:
sans = [item.value for item in san_obj.value]
result = {
'Certificate': {
'CertificateArn': self.arn,
'DomainName': self.common_name,
'InUseBy': [],
'Issuer': self._cert.issuer.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value,
'KeyAlgorithm': key_algo,
'NotAfter': datetime_to_epoch(self._cert.not_valid_after),
'NotBefore': datetime_to_epoch(self._cert.not_valid_before),
'Serial': self._cert.serial,
'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''),
'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED.
'Subject': 'CN={0}'.format(self.common_name),
'SubjectAlternativeNames': sans,
'Type': self.type # One of IMPORTED, AMAZON_ISSUED
}
}
if self.type == 'IMPORTED':
result['Certificate']['ImportedAt'] = datetime_to_epoch(self.created_at)
else:
result['Certificate']['CreatedAt'] = datetime_to_epoch(self.created_at)
result['Certificate']['IssuedAt'] = datetime_to_epoch(self.created_at)
return result
def __str__(self):
return self.arn
def __repr__(self):
return '<Certificate>'
class AWSCertificateManagerBackend(BaseBackend):
def __init__(self, region):
super(AWSCertificateManagerBackend, self).__init__()
self.region = region
self._certificates = {}
self._idempotency_tokens = {}
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@staticmethod
def _arn_not_found(arn):
msg = 'Certificate with arn {0} not found in account {1}'.format(arn, DEFAULT_ACCOUNT_ID)
return AWSResourceNotFoundException(msg)
def _get_arn_from_idempotency_token(self, token):
"""
If token doesnt exist, return None, later it will be
set with an expiry and arn.
If token expiry has passed, delete entry and return None
Else return ARN
:param token: String token
:return: None or ARN
"""
now = datetime.datetime.now()
if token in self._idempotency_tokens:
if self._idempotency_tokens[token]['expires'] < now:
# Token has expired, new request
del self._idempotency_tokens[token]
return None
else:
return self._idempotency_tokens[token]['arn']
return None
def _set_idempotency_token_arn(self, token, arn):
self._idempotency_tokens[token] = {'arn': arn, 'expires': datetime.datetime.now() + datetime.timedelta(hours=1)}
def import_cert(self, certificate, private_key, chain=None, arn=None):
if arn is not None:
if arn not in self._certificates:
raise self._arn_not_found(arn)
else:
# Will reuse provided ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region, arn=arn)
else:
# Will generate a random ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region)
self._certificates[bundle.arn] = bundle
return bundle.arn
def get_certificates_list(self):
"""
Get list of certificates
:return: List of certificates
:rtype: list of CertBundle
"""
for arn in self._certificates.keys():
yield self.get_certificate(arn)
def get_certificate(self, arn):
if arn not in self._certificates:
raise self._arn_not_found(arn)
cert_bundle = self._certificates[arn]
cert_bundle.check()
return cert_bundle
def delete_certificate(self, arn):
if arn not in self._certificates:
raise self._arn_not_found(arn)
del self._certificates[arn]
def request_certificate(self, domain_name, domain_validation_options, idempotency_token, subject_alt_names):
if idempotency_token is not None:
arn = self._get_arn_from_idempotency_token(idempotency_token)
if arn is not None:
return arn
cert = CertBundle.generate_cert(domain_name, subject_alt_names)
if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert
return cert.arn
def add_tags_to_certificate(self, arn, tags):
# get_cert does arn check
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag['Key']
value = tag.get('Value', None)
cert_bundle.tags[key] = value
def remove_tags_from_certificate(self, arn, tags):
# get_cert does arn check
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag['Key']
value = tag.get('Value', None)
try:
# If value isnt provided, just delete key
if value is None:
del cert_bundle.tags[key]
# If value is provided, only delete if it matches what already exists
elif cert_bundle.tags[key] == value:
del cert_bundle.tags[key]
except KeyError:
pass
acm_backends = {}
for region, ec2_backend in ec2_backends.items():
acm_backends[region] = AWSCertificateManagerBackend(region)

224
moto/acm/responses.py Normal file
View File

@ -0,0 +1,224 @@
from __future__ import unicode_literals
import json
import base64
from moto.core.responses import BaseResponse
from .models import acm_backends, AWSError, AWSValidationException
class AWSCertificateManagerResponse(BaseResponse):
@property
def acm_backend(self):
"""
ACM Backend
:return: ACM Backend object
:rtype: moto.acm.models.AWSCertificateManagerBackend
"""
return acm_backends[self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param, default=None):
return self.request_params.get(param, default)
def add_tags_to_certificate(self):
arn = self._get_param('CertificateArn')
tags = self._get_param('Tags')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
self.acm_backend.add_tags_to_certificate(arn, tags)
except AWSError as err:
return err.response()
return ''
def delete_certificate(self):
arn = self._get_param('CertificateArn')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
self.acm_backend.delete_certificate(arn)
except AWSError as err:
return err.response()
return ''
def describe_certificate(self):
arn = self._get_param('CertificateArn')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
return json.dumps(cert_bundle.describe())
def get_certificate(self):
arn = self._get_param('CertificateArn')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
result = {
'Certificate': cert_bundle.cert.decode(),
'CertificateChain': cert_bundle.chain.decode()
}
return json.dumps(result)
def import_certificate(self):
"""
Returns errors on:
Certificate, PrivateKey or Chain not being properly formatted
Arn not existing if its provided
PrivateKey size > 2048
Certificate expired or is not yet in effect
Does not return errors on:
Checking Certificate is legit, or a selfsigned chain is provided
:return: str(JSON) for response
"""
certificate = self._get_param('Certificate')
private_key = self._get_param('PrivateKey')
chain = self._get_param('CertificateChain') # Optional
current_arn = self._get_param('CertificateArn') # Optional
# Simple parameter decoding. Rather do it here as its a data transport decision not part of the
# actual data
try:
certificate = base64.standard_b64decode(certificate)
except:
return AWSValidationException('The certificate is not PEM-encoded or is not valid.').response()
try:
private_key = base64.standard_b64decode(private_key)
except:
return AWSValidationException('The private key is not PEM-encoded or is not valid.').response()
if chain is not None:
try:
chain = base64.standard_b64decode(chain)
except:
return AWSValidationException('The certificate chain is not PEM-encoded or is not valid.').response()
try:
arn = self.acm_backend.import_cert(certificate, private_key, chain=chain, arn=current_arn)
except AWSError as err:
return err.response()
return json.dumps({'CertificateArn': arn})
def list_certificates(self):
certs = []
for cert_bundle in self.acm_backend.get_certificates_list():
certs.append({
'CertificateArn': cert_bundle.arn,
'DomainName': cert_bundle.common_name
})
result = {'CertificateSummaryList': certs}
return json.dumps(result)
def list_tags_for_certificate(self):
arn = self._get_param('CertificateArn')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return {'__type': 'MissingParameter', 'message': msg}, dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
result = {'Tags': []}
# Tag "objects" can not contain the Value part
for key, value in cert_bundle.tags.items():
tag_dict = {'Key': key}
if value is not None:
tag_dict['Value'] = value
result['Tags'].append(tag_dict)
return json.dumps(result)
def remove_tags_from_certificate(self):
arn = self._get_param('CertificateArn')
tags = self._get_param('Tags')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
self.acm_backend.remove_tags_from_certificate(arn, tags)
except AWSError as err:
return err.response()
return ''
def request_certificate(self):
domain_name = self._get_param('DomainName')
domain_validation_options = self._get_param('DomainValidationOptions') # is ignored atm
idempotency_token = self._get_param('IdempotencyToken')
subject_alt_names = self._get_param('SubjectAlternativeNames')
if len(subject_alt_names) > 10:
# There is initial AWS limit of 10
msg = 'An ACM limit has been exceeded. Need to request SAN limit to be raised'
return json.dumps({'__type': 'LimitExceededException', 'message': msg}), dict(status=400)
try:
arn = self.acm_backend.request_certificate(domain_name, domain_validation_options, idempotency_token, subject_alt_names)
except AWSError as err:
return err.response()
return json.dumps({'CertificateArn': arn})
def resend_validation_email(self):
arn = self._get_param('CertificateArn')
domain = self._get_param('Domain')
# ValidationDomain not used yet.
# Contains domain which is equal to or a subset of Domain
# that AWS will send validation emails to
# https://docs.aws.amazon.com/acm/latest/APIReference/API_ResendValidationEmail.html
# validation_domain = self._get_param('ValidationDomain')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
if cert_bundle.common_name != domain:
msg = 'Parameter Domain does not match certificate domain'
_type = 'InvalidDomainValidationOptionsException'
return json.dumps({'__type': _type, 'message': msg}), dict(status=400)
except AWSError as err:
return err.response()
return ''

10
moto/acm/urls.py Normal file
View File

@ -0,0 +1,10 @@
from __future__ import unicode_literals
from .responses import AWSCertificateManagerResponse
url_bases = [
"https?://acm.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': AWSCertificateManagerResponse.dispatch,
}

7
moto/acm/utils.py Normal file
View File

@ -0,0 +1,7 @@
import uuid
def make_arn_for_certificate(account_id, region_name):
# Example
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(region_name, account_id, uuid.uuid4())

View File

@ -1,34 +1,150 @@
from __future__ import unicode_literals
import base64
from collections import defaultdict
import datetime
import docker.errors
import hashlib
import io
import logging
import os
import json
import sys
import re
import zipfile
try:
from StringIO import StringIO
except:
from io import StringIO
import uuid
import functools
import tarfile
import calendar
import threading
import traceback
import requests.adapters
import boto.awslambda
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time_millis
from moto.s3.models import s3_backend
from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings
logger = logging.getLogger(__name__)
try:
from tempfile import TemporaryDirectory
except ImportError:
from backports.tempfile import TemporaryDirectory
_stderr_regex = re.compile(r'START|END|REPORT RequestId: .*')
_orig_adapter_send = requests.adapters.HTTPAdapter.send
def zip2tar(zip_bytes):
with TemporaryDirectory() as td:
tarname = os.path.join(td, 'data.tar')
timeshift = int((datetime.datetime.now() -
datetime.datetime.utcnow()).total_seconds())
with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as zipf, \
tarfile.TarFile(tarname, 'w') as tarf:
for zipinfo in zipf.infolist():
if zipinfo.filename[-1] == '/': # is_dir() is py3.6+
continue
tarinfo = tarfile.TarInfo(name=zipinfo.filename)
tarinfo.size = zipinfo.file_size
tarinfo.mtime = calendar.timegm(zipinfo.date_time) - timeshift
infile = zipf.open(zipinfo.filename)
tarf.addfile(tarinfo, infile)
with open(tarname, 'rb') as f:
tar_data = f.read()
return tar_data
class _VolumeRefCount:
__slots__ = "refcount", "volume"
def __init__(self, refcount, volume):
self.refcount = refcount
self.volume = volume
class _DockerDataVolumeContext:
_data_vol_map = defaultdict(lambda: _VolumeRefCount(0, None)) # {sha256: _VolumeRefCount}
_lock = threading.Lock()
def __init__(self, lambda_func):
self._lambda_func = lambda_func
self._vol_ref = None
@property
def name(self):
return self._vol_ref.volume.name
def __enter__(self):
# See if volume is already known
with self.__class__._lock:
self._vol_ref = self.__class__._data_vol_map[self._lambda_func.code_sha_256]
self._vol_ref.refcount += 1
if self._vol_ref.refcount > 1:
return self
# See if the volume already exists
for vol in self._lambda_func.docker_client.volumes.list():
if vol.name == self._lambda_func.code_sha_256:
self._vol_ref.volume = vol
return self
# It doesn't exist so we need to create it
self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(self._lambda_func.code_sha_256)
container = self._lambda_func.docker_client.containers.run('alpine', 'sleep 100', volumes={self.name: '/tmp/data'}, detach=True)
try:
tar_bytes = zip2tar(self._lambda_func.code_bytes)
container.put_archive('/tmp/data', tar_bytes)
finally:
container.remove(force=True)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
with self.__class__._lock:
self._vol_ref.refcount -= 1
if self._vol_ref.refcount == 0:
try:
self._vol_ref.volume.remove()
except docker.errors.APIError as e:
if e.status_code != 409:
raise
raise # multiple processes trying to use same volume?
class LambdaFunction(BaseModel):
def __init__(self, spec, validate_s3=True):
def __init__(self, spec, region, validate_s3=True):
# required
self.region = region
self.code = spec['Code']
self.function_name = spec['FunctionName']
self.handler = spec['Handler']
self.role = spec['Role']
self.run_time = spec['Runtime']
self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get('Environment', {}).get('Variables', {})
self.docker_client = docker.from_env()
# Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked
if requests.adapters.HTTPAdapter.send != _orig_adapter_send:
_orig_get_adapter = self.docker_client.api.get_adapter
def replace_adapter_send(*args, **kwargs):
adapter = _orig_get_adapter(*args, **kwargs)
if isinstance(adapter, requests.adapters.HTTPAdapter):
adapter.send = functools.partial(_orig_adapter_send, adapter)
return adapter
self.docker_client.api.get_adapter = replace_adapter_send
# optional
self.description = spec.get('Description', '')
@ -36,13 +152,18 @@ class LambdaFunction(BaseModel):
self.publish = spec.get('Publish', False) # this is ignored currently
self.timeout = spec.get('Timeout', 3)
self.logs_group_name = '/aws/lambda/{}'.format(self.function_name)
self.logs_backend.ensure_log_group(self.logs_group_name, [])
# this isn't finished yet. it needs to find out the VpcId value
self._vpc_config = spec.get(
'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []})
# auto-generated
self.version = '$LATEST'
self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
self.last_modified = datetime.datetime.utcnow().strftime(
'%Y-%m-%d %H:%M:%S')
if 'ZipFile' in self.code:
# more hackery to handle unicode/bytes/str in python3 and python2 -
# argh!
@ -52,12 +173,13 @@ class LambdaFunction(BaseModel):
except Exception:
to_unzip_code = base64.b64decode(self.code['ZipFile'])
zbuffer = io.BytesIO()
zbuffer.write(to_unzip_code)
zip_file = zipfile.ZipFile(zbuffer, 'r', zipfile.ZIP_DEFLATED)
self.code = zip_file.read("".join(zip_file.namelist()))
self.code_bytes = to_unzip_code
self.code_size = len(to_unzip_code)
self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest()
# TODO: we should be putting this in a lambda bucket
self.code['UUID'] = str(uuid.uuid4())
self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID'])
else:
# validate s3 bucket and key
key = None
@ -76,10 +198,14 @@ class LambdaFunction(BaseModel):
"InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.")
if key:
self.code_bytes = key.value
self.code_size = key.size
self.code_sha_256 = hashlib.sha256(key.value).hexdigest()
self.function_arn = 'arn:aws:lambda:123456789012:function:{0}'.format(
self.function_name)
self.function_arn = 'arn:aws:lambda:{}:123456789012:function:{}'.format(
self.region, self.function_name)
self.tags = dict()
@property
def vpc_config(self):
@ -92,7 +218,7 @@ class LambdaFunction(BaseModel):
return json.dumps(self.get_configuration())
def get_configuration(self):
return {
config = {
"CodeSha256": self.code_sha_256,
"CodeSize": self.code_size,
"Description": self.description,
@ -108,65 +234,105 @@ class LambdaFunction(BaseModel):
"VpcConfig": self.vpc_config,
}
if self.environment_vars:
config['Environment'] = {
'Variables': self.environment_vars
}
return config
def get_code(self):
return {
"Code": {
"Location": "s3://lambda-functions.aws.amazon.com/{0}".format(self.code['S3Key']),
"Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(self.region, self.code['S3Key']),
"RepositoryType": "S3"
},
"Configuration": self.get_configuration(),
}
def convert(self, s):
@staticmethod
def convert(s):
try:
return str(s, encoding='utf-8')
except:
return s
def is_json(self, test_str):
@staticmethod
def is_json(test_str):
try:
response = json.loads(test_str)
except:
response = test_str
return response
def _invoke_lambda(self, code, event={}, context={}):
# TO DO: context not yet implemented
try:
mycode = "\n".join(['import json',
self.convert(self.code),
self.convert('print(json.dumps(lambda_handler(%s, %s)))' % (self.is_json(self.convert(event)), context))])
def _invoke_lambda(self, code, event=None, context=None):
# TODO: context not yet implemented
if event is None:
event = dict()
if context is None:
context = {}
except Exception as ex:
print("Exception %s", ex)
errored = False
try:
original_stdout = sys.stdout
original_stderr = sys.stderr
codeOut = StringIO()
codeErr = StringIO()
sys.stdout = codeOut
sys.stderr = codeErr
exec(mycode)
exec_err = codeErr.getvalue()
exec_out = codeOut.getvalue()
result = self.convert(exec_out.strip())
if exec_err:
result = "\n".join([exec_out.strip(), self.convert(exec_err)])
except Exception as ex:
errored = True
result = '%s\n\n\nException %s' % (mycode, ex)
finally:
codeErr.close()
codeOut.close()
sys.stdout = original_stdout
sys.stderr = original_stderr
return self.convert(result), errored
# TODO: I believe we can keep the container running and feed events as needed
# also need to hook it up to the other services so it can make kws/s3 etc calls
# Should get invoke_id /RequestId from invovation
env_vars = {
"AWS_LAMBDA_FUNCTION_TIMEOUT": self.timeout,
"AWS_LAMBDA_FUNCTION_NAME": self.function_name,
"AWS_LAMBDA_FUNCTION_MEMORY_SIZE": self.memory_size,
"AWS_LAMBDA_FUNCTION_VERSION": self.version,
"AWS_REGION": self.region,
}
env_vars.update(self.environment_vars)
container = output = exit_code = None
with _DockerDataVolumeContext(self) as data_vol:
try:
run_kwargs = dict(links={'motoserver': 'motoserver'}) if settings.TEST_SERVER_MODE else {}
container = self.docker_client.containers.run(
"lambci/lambda:{}".format(self.run_time),
[self.handler, json.dumps(event)], remove=False,
mem_limit="{}m".format(self.memory_size),
volumes=["{}:/var/task".format(data_vol.name)], environment=env_vars, detach=True, **run_kwargs)
finally:
if container:
exit_code = container.wait()
output = container.logs(stdout=False, stderr=True)
output += container.logs(stdout=True, stderr=False)
container.remove()
output = output.decode('utf-8')
# Send output to "logs" backend
invoke_id = uuid.uuid4().hex
log_stream_name = "{date.year}/{date.month:02d}/{date.day:02d}/[{version}]{invoke_id}".format(
date=datetime.datetime.utcnow(), version=self.version, invoke_id=invoke_id
)
self.logs_backend.create_log_stream(self.logs_group_name, log_stream_name)
log_events = [{'timestamp': unix_time_millis(), "message": line}
for line in output.splitlines()]
self.logs_backend.put_log_events(self.logs_group_name, log_stream_name, log_events, None)
if exit_code != 0:
raise Exception(
'lambda invoke failed output: {}'.format(output))
# strip out RequestId lines
output = os.linesep.join([line for line in self.convert(output).splitlines() if not _stderr_regex.match(line)])
return output, False
except BaseException as e:
traceback.print_exc()
return "error running lambda: {}".format(e), True
def invoke(self, body, request_headers, response_headers):
payload = dict()
if body:
body = json.loads(body)
# Get the invocation type:
res, errored = self._invoke_lambda(code=self.code, event=body)
if request_headers.get("x-amz-invocation-type") == "RequestResponse":
@ -182,7 +348,8 @@ class LambdaFunction(BaseModel):
return result
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
# required
@ -205,17 +372,19 @@ class LambdaFunction(BaseModel):
# this snippet converts this plaintext code to a proper base64-encoded ZIP file.
if 'ZipFile' in properties['Code']:
spec['Code']['ZipFile'] = base64.b64encode(
cls._create_zipfile_from_plaintext_code(spec['Code']['ZipFile']))
cls._create_zipfile_from_plaintext_code(
spec['Code']['ZipFile']))
backend = lambda_backends[region_name]
fn = backend.create_function(spec)
return fn
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
from moto.cloudformation.exceptions import \
UnformattedGetAttTemplateException
if attribute_name == 'Arn':
region = 'us-east-1'
return 'arn:aws:lambda:{0}:123456789012:function:{1}'.format(region, self.function_name)
return 'arn:aws:lambda:{0}:123456789012:function:{1}'.format(
self.region, self.function_name)
raise UnformattedGetAttTemplateException()
@staticmethod
@ -229,7 +398,6 @@ class LambdaFunction(BaseModel):
class EventSourceMapping(BaseModel):
def __init__(self, spec):
# required
self.function_name = spec['FunctionName']
@ -239,10 +407,12 @@ class EventSourceMapping(BaseModel):
# optional
self.batch_size = spec.get('BatchSize', 100)
self.enabled = spec.get('Enabled', True)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp', None)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp',
None)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
spec = {
'FunctionName': properties['FunctionName'],
@ -257,12 +427,12 @@ class EventSourceMapping(BaseModel):
class LambdaVersion(BaseModel):
def __init__(self, spec):
self.version = spec['Version']
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
spec = {
'Version': properties.get('Version')
@ -271,36 +441,99 @@ class LambdaVersion(BaseModel):
class LambdaBackend(BaseBackend):
def __init__(self):
def __init__(self, region_name):
self._functions = {}
self.region_name = region_name
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def has_function(self, function_name):
return function_name in self._functions
def has_function_arn(self, function_arn):
return self.get_function_by_arn(function_arn) is not None
def create_function(self, spec):
fn = LambdaFunction(spec)
fn = LambdaFunction(spec, self.region_name)
self._functions[fn.function_name] = fn
return fn
def get_function(self, function_name):
return self._functions[function_name]
def get_function_by_arn(self, function_arn):
for function in self._functions.values():
if function.function_arn == function_arn:
return function
return None
def delete_function(self, function_name):
del self._functions[function_name]
def list_functions(self):
return self._functions.values()
def send_message(self, function_name, message):
event = {
"Records": [
{
"EventVersion": "1.0",
"EventSubscriptionArn": "arn:aws:sns:EXAMPLE",
"EventSource": "aws:sns",
"Sns": {
"SignatureVersion": "1",
"Timestamp": "1970-01-01T00:00:00.000Z",
"Signature": "EXAMPLE",
"SigningCertUrl": "EXAMPLE",
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": message,
"MessageAttributes": {
"Test": {
"Type": "String",
"Value": "TestString"
},
"TestBinary": {
"Type": "Binary",
"Value": "TestBinary"
}
},
"Type": "Notification",
"UnsubscribeUrl": "EXAMPLE",
"TopicArn": "arn:aws:sns:EXAMPLE",
"Subject": "TestInvoke"
}
}
]
}
self._functions[function_name].invoke(json.dumps(event), {}, {})
pass
def list_tags(self, resource):
return self.get_function_by_arn(resource).tags
def tag_resource(self, resource, tags):
self.get_function_by_arn(resource).tags.update(tags)
def untag_resource(self, resource, tagKeys):
function = self.get_function_by_arn(resource)
for key in tagKeys:
try:
del function.tags[key]
except KeyError:
pass
# Don't care
def do_validate_s3():
return os.environ.get('VALIDATE_LAMBDA_S3', '') in ['', '1', 'true']
lambda_backends = {}
for region in boto.awslambda.regions():
lambda_backends[region.name] = LambdaBackend()
# Handle us forgotten regions, unless Lambda truly only runs out of US and
for region in ['ap-southeast-2']:
lambda_backends[region] = LambdaBackend()
lambda_backends = {_region.name: LambdaBackend(_region.name)
for _region in boto.awslambda.regions()}
lambda_backends['ap-southeast-2'] = LambdaBackend('ap-southeast-2')

View File

@ -3,6 +3,12 @@ from __future__ import unicode_literals
import json
import re
try:
from urllib import unquote
from urlparse import urlparse, parse_qs
except:
from urllib.parse import unquote, urlparse, parse_qs
from moto.core.responses import BaseResponse
@ -33,6 +39,24 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def invoke_async(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'POST':
return self._invoke_async(request, full_url)
else:
raise ValueError("Cannot handle request")
def tag(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
return self._list_tags(request, full_url)
elif request.method == 'POST':
return self._tag_resource(request, full_url)
elif request.method == 'DELETE':
return self._untag_resource(request, full_url)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def _invoke(self, request, full_url):
response_headers = {}
lambda_backend = self.get_lambda_backend(full_url)
@ -48,6 +72,20 @@ class LambdaResponse(BaseResponse):
else:
return 404, response_headers, "{}"
def _invoke_async(self, request, full_url):
response_headers = {}
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_name = path.split('/')[-3]
if lambda_backend.has_function(function_name):
fn = lambda_backend.get_function(function_name)
fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(0)
return 202, response_headers, ""
else:
return 404, response_headers, "{}"
def _list_functions(self, request, full_url, headers):
lambda_backend = self.get_lambda_backend(full_url)
return 200, {}, json.dumps({
@ -102,3 +140,43 @@ class LambdaResponse(BaseResponse):
return region.group(1)
else:
return self.default_region
def _list_tags(self, request, full_url):
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_arn = unquote(path.split('/')[-1])
if lambda_backend.has_function_arn(function_arn):
function = lambda_backend.get_function_by_arn(function_arn)
return 200, {}, json.dumps(dict(Tags=function.tags))
else:
return 404, {}, "{}"
def _tag_resource(self, request, full_url):
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_arn = unquote(path.split('/')[-1])
spec = json.loads(self.body)
if lambda_backend.has_function_arn(function_arn):
lambda_backend.tag_resource(function_arn, spec['Tags'])
return 200, {}, "{}"
else:
return 404, {}, "{}"
def _untag_resource(self, request, full_url):
lambda_backend = self.get_lambda_backend(full_url)
path = request.path if hasattr(request, 'path') else request.path_url
function_arn = unquote(path.split('/')[-1].split('?')[0])
tag_keys = parse_qs(urlparse(full_url).query)['tagKeys']
if lambda_backend.has_function_arn(function_arn):
lambda_backend.untag_resource(function_arn, tag_keys)
return 204, {}, "{}"
else:
return 404, {}, "{}"

View File

@ -9,6 +9,8 @@ response = LambdaResponse()
url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root,
'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag
}

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals
from moto.acm import acm_backends
from moto.apigateway import apigateway_backends
from moto.autoscaling import autoscaling_backends
from moto.awslambda import lambda_backends
@ -21,7 +22,9 @@ from moto.iam import iam_backends
from moto.instance_metadata import instance_metadata_backends
from moto.kinesis import kinesis_backends
from moto.kms import kms_backends
from moto.logs import logs_backends
from moto.opsworks import opsworks_backends
from moto.polly import polly_backends
from moto.rds2 import rds2_backends
from moto.redshift import redshift_backends
from moto.route53 import route53_backends
@ -31,8 +34,10 @@ from moto.sns import sns_backends
from moto.sqs import sqs_backends
from moto.ssm import ssm_backends
from moto.sts import sts_backends
from moto.xray import xray_backends
BACKENDS = {
'acm': acm_backends,
'apigateway': apigateway_backends,
'autoscaling': autoscaling_backends,
'cloudformation': cloudformation_backends,
@ -51,9 +56,11 @@ BACKENDS = {
'iam': iam_backends,
'moto_api': moto_api_backends,
'instance_metadata': instance_metadata_backends,
'opsworks': opsworks_backends,
'logs': logs_backends,
'kinesis': kinesis_backends,
'kms': kms_backends,
'opsworks': opsworks_backends,
'polly': polly_backends,
'redshift': redshift_backends,
'rds': rds2_backends,
's3': s3_backends,
@ -65,6 +72,7 @@ BACKENDS = {
'sts': sts_backends,
'route53': route53_backends,
'lambda': lambda_backends,
'xray': xray_backends
}

View File

@ -9,7 +9,7 @@ from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from .parsing import ResourceMap, OutputMap
from .utils import generate_stack_id
from .utils import generate_stack_id, yaml_tag_constructor
from .exceptions import ValidationError
@ -74,6 +74,7 @@ class FakeStack(BaseModel):
))
def _parse_template(self):
yaml.add_multi_constructor('', yaml_tag_constructor)
try:
self.template_dict = yaml.load(self.template)
except yaml.parser.ParserError:

View File

@ -391,8 +391,7 @@ LIST_STACKS_RESOURCES_RESPONSE = """<ListStackResourcesResponse>
GET_TEMPLATE_RESPONSE_TEMPLATE = """<GetTemplateResponse>
<GetTemplateResult>
<TemplateBody>{{ stack.template }}
</TemplateBody>
<TemplateBody>{{ stack.template }}</TemplateBody>
</GetTemplateResult>
<ResponseMetadata>
<RequestId>b9b4b068-3a41-11e5-94eb-example</RequestId>

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import uuid
import six
import random
import yaml
def generate_stack_id(stack_name):
@ -13,3 +14,22 @@ def random_suffix():
size = 12
chars = list(range(10)) + ['A-Z']
return ''.join(six.text_type(random.choice(chars)) for x in range(size))
def yaml_tag_constructor(loader, tag, node):
"""convert shorthand intrinsic function to full name
"""
def _f(loader, tag, node):
if tag == '!GetAtt':
return node.value.split('.')
elif type(node) == yaml.SequenceNode:
return loader.construct_sequence(node)
else:
return node.value
if tag == '!Ref':
key = 'Ref'
else:
key = 'Fn::{}'.format(tag[1:])
return {key: _f(loader, tag, node)}

View File

@ -2,6 +2,11 @@ from moto.core import BaseBackend, BaseModel
import boto.ec2.cloudwatch
import datetime
from .utils import make_arn_for_dashboard
DEFAULT_ACCOUNT_ID = 123456789012
class Dimension(object):
@ -44,10 +49,34 @@ class MetricDatum(BaseModel):
'value']) for dimension in dimensions]
class Dashboard(BaseModel):
def __init__(self, name, body):
# Guaranteed to be unique for now as the name is also the key of a dictionary where they are stored
self.arn = make_arn_for_dashboard(DEFAULT_ACCOUNT_ID, name)
self.name = name
self.body = body
self.last_modified = datetime.datetime.now()
@property
def last_modified_iso(self):
return self.last_modified.isoformat()
@property
def size(self):
return len(self)
def __len__(self):
return len(self.body)
def __repr__(self):
return '<CloudWatchDashboard {0}>'.format(self.name)
class CloudWatchBackend(BaseBackend):
def __init__(self):
self.alarms = {}
self.dashboards = {}
self.metric_data = []
def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods,
@ -110,6 +139,31 @@ class CloudWatchBackend(BaseBackend):
def get_all_metrics(self):
return self.metric_data
def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body)
def list_dashboards(self, prefix=''):
for key, value in self.dashboards.items():
if key.startswith(prefix):
yield value
def delete_dashboards(self, dashboards):
to_delete = set(dashboards)
all_dashboards = set(self.dashboards.keys())
left_over = to_delete - all_dashboards
if len(left_over) > 0:
# Some dashboards are not found
return False, 'The specified dashboard does not exist. [{0}]'.format(', '.join(left_over))
for dashboard in to_delete:
del self.dashboards[dashboard]
return True, None
def get_dashboard(self, dashboard):
return self.dashboards.get(dashboard)
class LogGroup(BaseModel):

View File

@ -1,9 +1,18 @@
import json
from moto.core.responses import BaseResponse
from .models import cloudwatch_backends
class CloudWatchResponse(BaseResponse):
@property
def cloudwatch_backend(self):
return cloudwatch_backends[self.region]
def _error(self, code, message, status=400):
template = self.response_template(ERROR_RESPONSE_TEMPLATE)
return template.render(code=code, message=message), dict(status=status)
def put_metric_alarm(self):
name = self._get_param('AlarmName')
namespace = self._get_param('Namespace')
@ -20,15 +29,14 @@ class CloudWatchResponse(BaseResponse):
insufficient_data_actions = self._get_multi_param(
"InsufficientDataActions.member")
unit = self._get_param('Unit')
cloudwatch_backend = cloudwatch_backends[self.region]
alarm = cloudwatch_backend.put_metric_alarm(name, namespace, metric_name,
comparison_operator,
evaluation_periods, period,
threshold, statistic,
description, dimensions,
alarm_actions, ok_actions,
insufficient_data_actions,
unit)
alarm = self.cloudwatch_backend.put_metric_alarm(name, namespace, metric_name,
comparison_operator,
evaluation_periods, period,
threshold, statistic,
description, dimensions,
alarm_actions, ok_actions,
insufficient_data_actions,
unit)
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm)
@ -37,28 +45,26 @@ class CloudWatchResponse(BaseResponse):
alarm_name_prefix = self._get_param('AlarmNamePrefix')
alarm_names = self._get_multi_param('AlarmNames.member')
state_value = self._get_param('StateValue')
cloudwatch_backend = cloudwatch_backends[self.region]
if action_prefix:
alarms = cloudwatch_backend.get_alarms_by_action_prefix(
alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(
action_prefix)
elif alarm_name_prefix:
alarms = cloudwatch_backend.get_alarms_by_alarm_name_prefix(
alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix(
alarm_name_prefix)
elif alarm_names:
alarms = cloudwatch_backend.get_alarms_by_alarm_names(alarm_names)
alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names)
elif state_value:
alarms = cloudwatch_backend.get_alarms_by_state_value(state_value)
alarms = self.cloudwatch_backend.get_alarms_by_state_value(state_value)
else:
alarms = cloudwatch_backend.get_all_alarms()
alarms = self.cloudwatch_backend.get_all_alarms()
template = self.response_template(DESCRIBE_ALARMS_TEMPLATE)
return template.render(alarms=alarms)
def delete_alarms(self):
alarm_names = self._get_multi_param('AlarmNames.member')
cloudwatch_backend = cloudwatch_backends[self.region]
cloudwatch_backend.delete_alarms(alarm_names)
self.cloudwatch_backend.delete_alarms(alarm_names)
template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE)
return template.render()
@ -89,17 +95,77 @@ class CloudWatchResponse(BaseResponse):
dimension_index += 1
metric_data.append([metric_name, value, dimensions])
metric_index += 1
cloudwatch_backend = cloudwatch_backends[self.region]
cloudwatch_backend.put_metric_data(namespace, metric_data)
self.cloudwatch_backend.put_metric_data(namespace, metric_data)
template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
return template.render()
def list_metrics(self):
cloudwatch_backend = cloudwatch_backends[self.region]
metrics = cloudwatch_backend.get_all_metrics()
metrics = self.cloudwatch_backend.get_all_metrics()
template = self.response_template(LIST_METRICS_TEMPLATE)
return template.render(metrics=metrics)
def delete_dashboards(self):
dashboards = self._get_multi_param('DashboardNames.member')
if dashboards is None:
return self._error('InvalidParameterValue', 'Need at least 1 dashboard')
status, error = self.cloudwatch_backend.delete_dashboards(dashboards)
if not status:
return self._error('ResourceNotFound', error)
template = self.response_template(DELETE_DASHBOARD_TEMPLATE)
return template.render()
def describe_alarm_history(self):
raise NotImplementedError()
def describe_alarms_for_metric(self):
raise NotImplementedError()
def disable_alarm_actions(self):
raise NotImplementedError()
def enable_alarm_actions(self):
raise NotImplementedError()
def get_dashboard(self):
dashboard_name = self._get_param('DashboardName')
dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name)
if dashboard is None:
return self._error('ResourceNotFound', 'Dashboard does not exist')
template = self.response_template(GET_DASHBOARD_TEMPLATE)
return template.render(dashboard=dashboard)
def get_metric_statistics(self):
raise NotImplementedError()
def list_dashboards(self):
prefix = self._get_param('DashboardNamePrefix', '')
dashboards = self.cloudwatch_backend.list_dashboards(prefix)
template = self.response_template(LIST_DASHBOARD_RESPONSE)
return template.render(dashboards=dashboards)
def put_dashboard(self):
name = self._get_param('DashboardName')
body = self._get_param('DashboardBody')
try:
json.loads(body)
except ValueError:
return self._error('InvalidParameterInput', 'Body is invalid JSON')
self.cloudwatch_backend.put_dashboard(name, body)
template = self.response_template(PUT_DASHBOARD_RESPONSE)
return template.render()
def set_alarm_state(self):
raise NotImplementedError()
PUT_METRIC_ALARM_TEMPLATE = """<PutMetricAlarmResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
@ -199,3 +265,58 @@ LIST_METRICS_TEMPLATE = """<ListMetricsResponse xmlns="http://monitoring.amazona
</NextToken>
</ListMetricsResult>
</ListMetricsResponse>"""
PUT_DASHBOARD_RESPONSE = """<PutDashboardResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<PutDashboardResult>
<DashboardValidationMessages/>
</PutDashboardResult>
<ResponseMetadata>
<RequestId>44b1d4d8-9fa3-11e7-8ad3-41b86ac5e49e</RequestId>
</ResponseMetadata>
</PutDashboardResponse>"""
LIST_DASHBOARD_RESPONSE = """<ListDashboardsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ListDashboardsResult>
<DashboardEntries>
{% for dashboard in dashboards %}
<member>
<DashboardArn>{{ dashboard.arn }}</DashboardArn>
<LastModified>{{ dashboard.last_modified_iso }}</LastModified>
<Size>{{ dashboard.size }}</Size>
<DashboardName>{{ dashboard.name }}</DashboardName>
</member>
{% endfor %}
</DashboardEntries>
</ListDashboardsResult>
<ResponseMetadata>
<RequestId>c3773873-9fa5-11e7-b315-31fcc9275d62</RequestId>
</ResponseMetadata>
</ListDashboardsResponse>"""
DELETE_DASHBOARD_TEMPLATE = """<DeleteDashboardsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<DeleteDashboardsResult/>
<ResponseMetadata>
<RequestId>68d1dc8c-9faa-11e7-a694-df2715690df2</RequestId>
</ResponseMetadata>
</DeleteDashboardsResponse>"""
GET_DASHBOARD_TEMPLATE = """<GetDashboardResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<GetDashboardResult>
<DashboardArn>{{ dashboard.arn }}</DashboardArn>
<DashboardBody>{{ dashboard.body }}</DashboardBody>
<DashboardName>{{ dashboard.name }}</DashboardName>
</GetDashboardResult>
<ResponseMetadata>
<RequestId>e3c16bb0-9faa-11e7-b315-31fcc9275d62</RequestId>
</ResponseMetadata>
</GetDashboardResponse>
"""
ERROR_RESPONSE_TEMPLATE = """<ErrorResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<Error>
<Type>Sender</Type>
<Code>{{ code }}</Code>
<Message>{{ message }}</Message>
</Error>
<RequestId>5e45fd1e-9fa3-11e7-b720-89e8821d38c4</RequestId>
</ErrorResponse>"""

5
moto/cloudwatch/utils.py Normal file
View File

@ -0,0 +1,5 @@
from __future__ import unicode_literals
def make_arn_for_dashboard(account_id, name):
return "arn:aws:cloudwatch::{0}dashboard/{1}".format(account_id, name)

View File

@ -167,7 +167,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
match = re.search(self.region_regex, full_url)
if match:
region = match.group(1)
elif 'Authorization' in request.headers:
elif 'Authorization' in request.headers and 'AWS4' in request.headers['Authorization']:
region = request.headers['Authorization'].split(",")[
0].split("/")[2]
else:
@ -178,8 +178,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.setup_class(request, full_url, headers)
return self.call_action()
def call_action(self):
headers = self.response_headers
def _get_action(self):
action = self.querystring.get('Action', [""])[0]
if not action: # Some services use a header for the action
# Headers are case-insensitive. Probably a better way to do this.
@ -188,7 +187,11 @@ class BaseResponse(_TemplateEnvironmentMixin):
if match:
action = match.split(".")[-1]
action = camelcase_to_underscores(action)
return action
def call_action(self):
headers = self.response_headers
action = camelcase_to_underscores(self._get_action())
method_names = method_names_from_class(self.__class__)
if action in method_names:
method = getattr(self, action)
@ -310,7 +313,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
param_index += 1
return results
def _get_map_prefix(self, param_prefix):
def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'):
results = {}
param_index = 1
while 1:
@ -319,9 +322,9 @@ class BaseResponse(_TemplateEnvironmentMixin):
k, v = None, None
for key, value in self.querystring.items():
if key.startswith(index_prefix):
if key.endswith('.key'):
if key.endswith(key_end):
k = value[0]
elif key.endswith('.value'):
elif key.endswith(value_end):
v = value[0]
if not (k and v):

View File

@ -7,33 +7,6 @@ from moto.core.utils import camelcase_to_underscores
from .models import dynamodb_backend, dynamo_json_dump
GET_SESSION_TOKEN_RESULT = """
<GetSessionTokenResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetSessionTokenResult>
<Credentials>
<SessionToken>
AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/L
To6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3z
rkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtp
Z3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE
</SessionToken>
<SecretAccessKey>
wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY
</SecretAccessKey>
<Expiration>2011-07-11T19:55:29.611Z</Expiration>
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId>
</Credentials>
</GetSessionTokenResult>
<ResponseMetadata>
<RequestId>58c5dbae-abef-11e0-8cfe-09039844ac7d</RequestId>
</ResponseMetadata>
</GetSessionTokenResponse>"""
def sts_handler():
return GET_SESSION_TOKEN_RESULT
class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers):
@ -51,11 +24,7 @@ class DynamoHandler(BaseResponse):
return status, self.response_headers, dynamo_json_dump({'__type': type_})
def call_action(self):
body = self.body
if 'GetSessionToken' in body:
return 200, self.response_headers, sts_handler()
self.body = json.loads(body or '{}')
self.body = json.loads(self.body or '{}')
endpoint = self.get_endpoint_name(self.headers)
if endpoint:
endpoint = camelcase_to_underscores(endpoint)

View File

@ -2,8 +2,7 @@ from __future__ import unicode_literals
from .responses import DynamoHandler
url_bases = [
"https?://dynamodb.(.+).amazonaws.com",
"https?://sts.amazonaws.com",
"https?://dynamodb.(.+).amazonaws.com"
]
url_paths = {

View File

@ -57,7 +57,7 @@ class DynamoType(object):
@property
def cast_value(self):
if self.type == 'N':
if self.is_number():
try:
return int(self.value)
except ValueError:
@ -76,6 +76,15 @@ class DynamoType(object):
comparison_func = get_comparison_func(range_comparison)
return comparison_func(self.cast_value, *range_values)
def is_number(self):
return self.type == 'N'
def is_set(self):
return self.type == 'SS' or self.type == 'NS' or self.type == 'BS'
def same_type(self, other):
return self.type == other.type
class Item(BaseModel):
@ -118,10 +127,11 @@ class Item(BaseModel):
def update(self, update_expression, expression_attribute_names, expression_attribute_values):
# Update subexpressions are identifiable by the operator keyword, so split on that and
# get rid of the empty leading string.
parts = [p for p in re.split(r'\b(SET|REMOVE|ADD|DELETE)\b', update_expression) if p]
parts = [p for p in re.split(r'\b(SET|REMOVE|ADD|DELETE)\b', update_expression, flags=re.I) if p]
# make sure that we correctly found only operator/value pairs
assert len(parts) % 2 == 0, "Mismatched operators and values in update expression: '{}'".format(update_expression)
for action, valstr in zip(parts[:-1:2], parts[1::2]):
action = action.upper()
values = valstr.split(',')
for value in values:
# A Real value
@ -139,6 +149,55 @@ class Item(BaseModel):
self.attrs[key] = DynamoType(expression_attribute_values[value])
else:
self.attrs[key] = DynamoType({"S": value})
elif action == 'ADD':
key, value = value.split(" ", 1)
key = key.strip()
value_str = value.strip()
if value_str in expression_attribute_values:
dyn_value = DynamoType(expression_attribute_values[value])
else:
raise TypeError
# Handle adding numbers - value gets added to existing value,
# or added to 0 if it doesn't exist yet
if dyn_value.is_number():
existing = self.attrs.get(key, DynamoType({"N": '0'}))
if not existing.same_type(dyn_value):
raise TypeError()
self.attrs[key] = DynamoType({"N": str(
decimal.Decimal(existing.value) +
decimal.Decimal(dyn_value.value)
)})
# Handle adding sets - value is added to the set, or set is
# created with only this value if it doesn't exist yet
# New value must be of same set type as previous value
elif dyn_value.is_set():
existing = self.attrs.get(key, DynamoType({dyn_value.type: {}}))
if not existing.same_type(dyn_value):
raise TypeError()
new_set = set(existing.value).union(dyn_value.value)
self.attrs[key] = DynamoType({existing.type: list(new_set)})
else: # Number and Sets are the only supported types for ADD
raise TypeError
elif action == 'DELETE':
key, value = value.split(" ", 1)
key = key.strip()
value_str = value.strip()
if value_str in expression_attribute_values:
dyn_value = DynamoType(expression_attribute_values[value])
else:
raise TypeError
if not dyn_value.is_set():
raise TypeError
existing = self.attrs.get(key, None)
if existing:
if not existing.same_type(dyn_value):
raise TypeError
new_set = set(existing.value).difference(dyn_value.value)
self.attrs[key] = DynamoType({existing.type: list(new_set)})
else:
raise NotImplementedError('{} update action not yet supported'.format(action))
@ -171,6 +230,12 @@ class Item(BaseModel):
decimal.Decimal(existing.value) +
decimal.Decimal(new_value)
)})
elif set(update_action['Value'].keys()) == set(['SS']):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
new_set = set(existing.value).union(set(new_value))
self.attrs[attribute_name] = DynamoType({
"SS": list(new_set)
})
else:
# TODO: implement other data types
raise NotImplementedError(
@ -347,7 +412,8 @@ class Table(BaseModel):
return None
def query(self, hash_key, range_comparison, range_objs, limit,
exclusive_start_key, scan_index_forward, index_name=None, **filter_kwargs):
exclusive_start_key, scan_index_forward, projection_expression,
index_name=None, **filter_kwargs):
results = []
if index_name:
all_indexes = (self.global_indexes or []) + (self.indexes or [])
@ -418,6 +484,13 @@ class Table(BaseModel):
else:
results.sort(key=lambda item: item.range_key)
if projection_expression:
expressions = [x.strip() for x in projection_expression.split(',')]
for result in possible_results:
for attr in list(result.attrs):
if attr not in expressions:
result.attrs.pop(attr)
if scan_index_forward is False:
results.reverse()
@ -613,7 +686,7 @@ class DynamoDBBackend(BaseBackend):
return table.get_item(hash_key, range_key)
def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts,
limit, exclusive_start_key, scan_index_forward, index_name=None, **filter_kwargs):
limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None, **filter_kwargs):
table = self.tables.get(table_name)
if not table:
return None, None
@ -623,7 +696,7 @@ class DynamoDBBackend(BaseBackend):
for range_value in range_value_dicts]
return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, index_name, **filter_kwargs)
exclusive_start_key, scan_index_forward, projection_expression, index_name, **filter_kwargs)
def scan(self, table_name, filters, limit, exclusive_start_key):
table = self.tables.get(table_name)

View File

@ -8,33 +8,6 @@ from moto.core.utils import camelcase_to_underscores
from .models import dynamodb_backend2, dynamo_json_dump
GET_SESSION_TOKEN_RESULT = """
<GetSessionTokenResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetSessionTokenResult>
<Credentials>
<SessionToken>
AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/L
To6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3z
rkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtp
Z3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE
</SessionToken>
<SecretAccessKey>
wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY
</SecretAccessKey>
<Expiration>2011-07-11T19:55:29.611Z</Expiration>
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId>
</Credentials>
</GetSessionTokenResult>
<ResponseMetadata>
<RequestId>58c5dbae-abef-11e0-8cfe-09039844ac7d</RequestId>
</ResponseMetadata>
</GetSessionTokenResponse>"""
def sts_handler():
return GET_SESSION_TOKEN_RESULT
class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers):
@ -48,15 +21,11 @@ class DynamoHandler(BaseResponse):
if match:
return match.split(".")[1]
def error(self, type_, status=400):
return status, self.response_headers, dynamo_json_dump({'__type': type_})
def error(self, type_, message, status=400):
return status, self.response_headers, dynamo_json_dump({'__type': type_, 'message': message})
def call_action(self):
body = self.body
if 'GetSessionToken' in body:
return 200, self.response_headers, sts_handler()
self.body = json.loads(body or '{}')
self.body = json.loads(self.body or '{}')
endpoint = self.get_endpoint_name(self.headers)
if endpoint:
endpoint = camelcase_to_underscores(endpoint)
@ -113,7 +82,7 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(table.describe())
else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceInUseException'
return self.error(er)
return self.error(er, 'Resource in use')
def delete_table(self):
name = self.body['TableName']
@ -122,7 +91,7 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(table.describe())
else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er)
return self.error(er, 'Requested resource not found')
def tag_resource(self):
tags = self.body['Tags']
@ -151,7 +120,7 @@ class DynamoHandler(BaseResponse):
return json.dumps({'Tags': tags_resp})
except AttributeError:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er)
return self.error(er, 'Requested resource not found')
def update_table(self):
name = self.body['TableName']
@ -169,12 +138,24 @@ class DynamoHandler(BaseResponse):
table = dynamodb_backend2.tables[name]
except KeyError:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er)
return self.error(er, 'Requested resource not found')
return dynamo_json_dump(table.describe(base_key='Table'))
def put_item(self):
name = self.body['TableName']
item = self.body['Item']
res = re.search('\"\"', json.dumps(item))
if res:
er = 'com.amazonaws.dynamodb.v20111205#ValidationException'
return (400,
{'server': 'amazon.com'},
dynamo_json_dump({'__type': er,
'message': ('One or more parameter values were '
'invalid: An AttributeValue may not '
'contain an empty string')}
))
overwrite = 'Expected' not in self.body
if not overwrite:
expected = self.body['Expected']
@ -209,15 +190,18 @@ class DynamoHandler(BaseResponse):
name, item, expected, overwrite)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er)
return self.error(er, 'A condition specified in the operation could not be evaluated.')
if result:
item_dict = result.to_json()
item_dict['ConsumedCapacityUnits'] = 1
item_dict['ConsumedCapacity'] = {
'TableName': name,
'CapacityUnits': 1
}
return dynamo_json_dump(item_dict)
else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er)
return self.error(er, 'Requested resource not found')
def batch_write_item(self):
table_batches = self.body['RequestItems']
@ -254,15 +238,17 @@ class DynamoHandler(BaseResponse):
item = dynamodb_backend2.get_item(name, key)
except ValueError:
er = 'com.amazon.coral.validate#ValidationException'
return self.error(er, status=400)
return self.error(er, 'Validation Exception')
if item:
item_dict = item.describe_attrs(attributes=None)
item_dict['ConsumedCapacityUnits'] = 0.5
item_dict['ConsumedCapacity'] = {
'TableName': name,
'CapacityUnits': 0.5
}
return dynamo_json_dump(item_dict)
else:
# Item not found
er = '{}'
return self.error(er, status=200)
return 200, self.response_headers, '{}'
def batch_get_item(self):
table_batches = self.body['RequestItems']
@ -296,11 +282,26 @@ class DynamoHandler(BaseResponse):
name = self.body['TableName']
# {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}}
key_condition_expression = self.body.get('KeyConditionExpression')
projection_expression = self.body.get('ProjectionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames')
if projection_expression and expression_attribute_names:
expressions = [x.strip() for x in projection_expression.split(',')]
for expression in expressions:
if expression in expression_attribute_names:
projection_expression = projection_expression.replace(expression, expression_attribute_names[expression])
filter_kwargs = {}
if key_condition_expression:
value_alias_map = self.body['ExpressionAttributeValues']
table = dynamodb_backend2.get_table(name)
# If table does not exist
if table is None:
return self.error('com.amazonaws.dynamodb.v20120810#ResourceNotFoundException',
'Requested resource not found')
index_name = self.body.get('IndexName')
if index_name:
all_indexes = (table.global_indexes or []) + \
@ -369,7 +370,7 @@ class DynamoHandler(BaseResponse):
filter_kwargs[key] = value
if hash_key_name is None:
er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException"
return self.error(er)
return self.error(er, 'Requested resource not found')
hash_key = key_conditions[hash_key_name][
'AttributeValueList'][0]
if len(key_conditions) == 1:
@ -378,7 +379,7 @@ class DynamoHandler(BaseResponse):
else:
if range_key_name is None and not filter_kwargs:
er = "com.amazon.coral.validate#ValidationException"
return self.error(er)
return self.error(er, 'Validation Exception')
else:
range_condition = key_conditions.get(range_key_name)
if range_condition:
@ -397,16 +398,20 @@ class DynamoHandler(BaseResponse):
scan_index_forward = self.body.get("ScanIndexForward")
items, scanned_count, last_evaluated_key = dynamodb_backend2.query(
name, hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, index_name=index_name, **filter_kwargs)
exclusive_start_key, scan_index_forward, projection_expression, index_name=index_name, **filter_kwargs)
if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er)
return self.error(er, 'Requested resource not found')
result = {
"Count": len(items),
"ConsumedCapacityUnits": 1,
'ConsumedCapacity': {
'TableName': name,
'CapacityUnits': 1,
},
"ScannedCount": scanned_count
}
if self.body.get('Select', '').upper() != 'COUNT':
result["Items"] = [item.attrs for item in items]
@ -436,12 +441,15 @@ class DynamoHandler(BaseResponse):
if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er)
return self.error(er, 'Requested resource not found')
result = {
"Count": len(items),
"Items": [item.attrs for item in items],
"ConsumedCapacityUnits": 1,
'ConsumedCapacity': {
'TableName': name,
'CapacityUnits': 1,
},
"ScannedCount": scanned_count
}
if last_evaluated_key is not None:
@ -455,7 +463,7 @@ class DynamoHandler(BaseResponse):
table = dynamodb_backend2.get_table(name)
if not table:
er = 'com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException'
return self.error(er)
return self.error(er, 'A condition specified in the operation could not be evaluated.')
item = dynamodb_backend2.delete_item(name, keys)
if item and return_values == 'ALL_OLD':
@ -515,10 +523,16 @@ class DynamoHandler(BaseResponse):
expected)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er)
return self.error(er, 'A condition specified in the operation could not be evaluated.')
except TypeError:
er = 'com.amazonaws.dynamodb.v20111205#ValidationException'
return self.error(er, 'Validation Exception')
item_dict = item.to_json()
item_dict['ConsumedCapacityUnits'] = 0.5
item_dict['ConsumedCapacity'] = {
'TableName': name,
'CapacityUnits': 0.5
}
if not existing_item:
item_dict['Attributes'] = {}

View File

@ -2,8 +2,7 @@ from __future__ import unicode_literals
from .responses import DynamoHandler
url_bases = [
"https?://dynamodb.(.+).amazonaws.com",
"https?://sts.amazonaws.com",
"https?://dynamodb.(.+).amazonaws.com"
]
url_paths = {

View File

@ -375,3 +375,20 @@ class RulesPerSecurityGroupLimitExceededError(EC2ClientError):
"RulesPerSecurityGroupLimitExceeded",
'The maximum number of rules per security group '
'has been reached.')
class MotoNotImplementedError(NotImplementedError):
def __init__(self, blurb):
super(MotoNotImplementedError, self).__init__(
"{0} has not been implemented in Moto yet."
" Feel free to open an issue at"
" https://github.com/spulec/moto/issues".format(blurb))
class FilterNotImplementedError(MotoNotImplementedError):
def __init__(self, filter_name, method_name):
super(FilterNotImplementedError, self).__init__(
"The filter '{0}' for {1}".format(
filter_name, method_name))

View File

@ -2,6 +2,8 @@ from __future__ import unicode_literals
import copy
import itertools
import json
import os
import re
import six
@ -62,6 +64,8 @@ from .exceptions import (
InvalidVpnConnectionIdError,
InvalidCustomerGatewayIdError,
RulesPerSecurityGroupLimitExceededError,
MotoNotImplementedError,
FilterNotImplementedError
)
from .utils import (
EC2_RESOURCE_TO_PREFIX,
@ -107,6 +111,9 @@ from .utils import (
is_tag_filter,
)
RESOURCES_DIR = os.path.join(os.path.dirname(__file__), 'resources')
INSTANCE_TYPES = json.load(open(os.path.join(RESOURCES_DIR, 'instance_types.json'), 'r'))
def utc_date_and_time():
return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z')
@ -144,7 +151,7 @@ class TaggedEC2Resource(BaseModel):
for key, value in tag_map.items():
self.ec2_backend.create_tags([self.id], {key: value})
def get_filter_value(self, filter_name):
def get_filter_value(self, filter_name, method_name=None):
tags = self.get_tags()
if filter_name.startswith('tag:'):
@ -154,12 +161,12 @@ class TaggedEC2Resource(BaseModel):
return tag['value']
return ''
if filter_name == 'tag-key':
elif filter_name == 'tag-key':
return [tag['key'] for tag in tags]
if filter_name == 'tag-value':
elif filter_name == 'tag-value':
return [tag['value'] for tag in tags]
else:
raise FilterNotImplementedError(filter_name, method_name)
class NetworkInterface(TaggedEC2Resource):
@ -261,17 +268,9 @@ class NetworkInterface(TaggedEC2Resource):
return [group.id for group in self._group_set]
elif filter_name == 'availability-zone':
return self.subnet.availability_zone
filter_value = super(
NetworkInterface, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeNetworkInterfaces".format(
filter_name)
)
return filter_value
else:
return super(NetworkInterface, self).get_filter_value(
filter_name, 'DescribeNetworkInterfaces')
class NetworkInterfaceBackend(object):
@ -366,6 +365,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.user_data = user_data
self.security_groups = security_groups
self.instance_type = kwargs.get("instance_type", "m1.small")
self.region_name = kwargs.get("region_name", "us-east-1")
placement = kwargs.get("placement", None)
self.vpc_id = None
self.subnet_id = kwargs.get("subnet_id")
@ -433,7 +433,11 @@ class Instance(TaggedEC2Resource, BotoInstance):
@property
def private_dns(self):
return "ip-{0}.ec2.internal".format(self.private_ip)
formatted_ip = self.private_ip.replace('.', '-')
if self.region_name == "us-east-1":
return "ip-{0}.ec2.internal".format(formatted_ip)
else:
return "ip-{0}.{1}.compute.internal".format(formatted_ip, self.region_name)
@property
def public_ip(self):
@ -442,7 +446,11 @@ class Instance(TaggedEC2Resource, BotoInstance):
@property
def public_dns(self):
if self.public_ip:
return "ec2-{0}.compute-1.amazonaws.com".format(self.public_ip)
formatted_ip = self.public_ip.replace('.', '-')
if self.region_name == "us-east-1":
return "ec2-{0}.compute-1.amazonaws.com".format(formatted_ip)
else:
return "ec2-{0}.{1}.compute.amazonaws.com".format(formatted_ip, self.region_name)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -581,10 +589,6 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.attach_eni(use_nic, device_index)
def set_ip(self, ip_address):
# Should we be creating a new ENI?
self.nics[0].public_ip = ip_address
def attach_eni(self, eni, device_index):
device_index = int(device_index)
self.nics[device_index] = eni
@ -786,16 +790,31 @@ class InstanceBackend(object):
return reservations
class KeyPair(object):
def __init__(self, name, fingerprint, material):
self.name = name
self.fingerprint = fingerprint
self.material = material
def get_filter_value(self, filter_name):
if filter_name == 'key-name':
return self.name
elif filter_name == 'fingerprint':
return self.fingerprint
else:
raise FilterNotImplementedError(filter_name, 'DescribeKeyPairs')
class KeyPairBackend(object):
def __init__(self):
self.keypairs = defaultdict(dict)
self.keypairs = {}
super(KeyPairBackend, self).__init__()
def create_key_pair(self, name):
if name in self.keypairs:
raise InvalidKeyPairDuplicateError(name)
self.keypairs[name] = keypair = random_key_pair()
keypair['name'] = name
keypair = KeyPair(name, **random_key_pair())
self.keypairs[name] = keypair
return keypair
def delete_key_pair(self, name):
@ -803,24 +822,27 @@ class KeyPairBackend(object):
self.keypairs.pop(name)
return True
def describe_key_pairs(self, filter_names=None):
def describe_key_pairs(self, key_names=None, filters=None):
results = []
for name, keypair in self.keypairs.items():
if not filter_names or name in filter_names:
keypair['name'] = name
results.append(keypair)
if key_names:
results = [keypair for keypair in self.keypairs.values()
if keypair.name in key_names]
if len(key_names) > len(results):
unknown_keys = set(key_names) - set(results)
raise InvalidKeyPairNameError(unknown_keys)
else:
results = self.keypairs.values()
# TODO: Trim error message down to specific invalid name.
if filter_names and len(filter_names) > len(results):
raise InvalidKeyPairNameError(filter_names)
return results
if filters:
return generic_filter(filters, results)
else:
return results
def import_key_pair(self, key_name, public_key_material):
if key_name in self.keypairs:
raise InvalidKeyPairDuplicateError(key_name)
self.keypairs[key_name] = keypair = random_key_pair()
keypair['name'] = key_name
keypair = KeyPair(key_name, **random_key_pair())
self.keypairs[key_name] = keypair
return keypair
@ -1018,14 +1040,9 @@ class Ami(TaggedEC2Resource):
return self.state
elif filter_name == 'name':
return self.name
filter_value = super(Ami, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeImages".format(filter_name))
return filter_value
else:
return super(Ami, self).get_filter_value(
filter_name, 'DescribeImages')
class AmiBackend(object):
@ -1348,22 +1365,25 @@ class SecurityGroupBackend(object):
return group
def describe_security_groups(self, group_ids=None, groupnames=None, filters=None):
all_groups = itertools.chain(*[x.values()
for x in self.groups.values()])
groups = []
matches = itertools.chain(*[x.values()
for x in self.groups.values()])
if group_ids:
matches = [grp for grp in matches
if grp.id in group_ids]
if len(group_ids) > len(matches):
unknown_ids = set(group_ids) - set(matches)
raise InvalidSecurityGroupNotFoundError(unknown_ids)
if groupnames:
matches = [grp for grp in matches
if grp.name in groupnames]
if len(groupnames) > len(matches):
unknown_names = set(groupnames) - set(matches)
raise InvalidSecurityGroupNotFoundError(unknown_names)
if filters:
matches = [grp for grp in matches
if grp.matches_filters(filters)]
if group_ids or groupnames or filters:
for group in all_groups:
if ((group_ids and group.id not in group_ids) or
(groupnames and group.name not in groupnames)):
continue
if filters and not group.matches_filters(filters):
continue
groups.append(group)
else:
groups = all_groups
return groups
return matches
def _delete_security_group(self, vpc_id, group_id):
if self.groups[vpc_id][group_id].enis:
@ -1682,43 +1702,31 @@ class Volume(TaggedEC2Resource):
return 'available'
def get_filter_value(self, filter_name):
if filter_name.startswith('attachment') and not self.attachment:
return None
if filter_name == 'attachment.attach-time':
elif filter_name == 'attachment.attach-time':
return self.attachment.attach_time
if filter_name == 'attachment.device':
elif filter_name == 'attachment.device':
return self.attachment.device
if filter_name == 'attachment.instance-id':
elif filter_name == 'attachment.instance-id':
return self.attachment.instance.id
if filter_name == 'attachment.status':
elif filter_name == 'attachment.status':
return self.attachment.status
if filter_name == 'create-time':
elif filter_name == 'create-time':
return self.create_time
if filter_name == 'size':
elif filter_name == 'size':
return self.size
if filter_name == 'snapshot-id':
elif filter_name == 'snapshot-id':
return self.snapshot_id
if filter_name == 'status':
elif filter_name == 'status':
return self.status
if filter_name == 'volume-id':
elif filter_name == 'volume-id':
return self.id
if filter_name == 'encrypted':
elif filter_name == 'encrypted':
return str(self.encrypted).lower()
filter_value = super(Volume, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeVolumes".format(filter_name))
return filter_value
else:
return super(Volume, self).get_filter_value(
filter_name, 'DescribeVolumes')
class Snapshot(TaggedEC2Resource):
@ -1733,35 +1741,23 @@ class Snapshot(TaggedEC2Resource):
self.encrypted = encrypted
def get_filter_value(self, filter_name):
if filter_name == 'description':
return self.description
if filter_name == 'snapshot-id':
elif filter_name == 'snapshot-id':
return self.id
if filter_name == 'start-time':
elif filter_name == 'start-time':
return self.start_time
if filter_name == 'volume-id':
elif filter_name == 'volume-id':
return self.volume.id
if filter_name == 'volume-size':
elif filter_name == 'volume-size':
return self.volume.size
if filter_name == 'encrypted':
elif filter_name == 'encrypted':
return str(self.encrypted).lower()
if filter_name == 'status':
elif filter_name == 'status':
return self.status
filter_value = super(Snapshot, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeSnapshots".format(filter_name))
return filter_value
else:
return super(Snapshot, self).get_filter_value(
filter_name, 'DescribeSnapshots')
class EBSBackend(object):
@ -1784,11 +1780,17 @@ class EBSBackend(object):
self.volumes[volume_id] = volume
return volume
def describe_volumes(self, filters=None):
def describe_volumes(self, volume_ids=None, filters=None):
matches = self.volumes.values()
if volume_ids:
matches = [vol for vol in matches
if vol.id in volume_ids]
if len(volume_ids) > len(matches):
unknown_ids = set(volume_ids) - set(matches)
raise InvalidVolumeIdError(unknown_ids)
if filters:
volumes = self.volumes.values()
return generic_filter(filters, volumes)
return self.volumes.values()
matches = generic_filter(filters, matches)
return matches
def get_volume(self, volume_id):
volume = self.volumes.get(volume_id, None)
@ -1836,11 +1838,17 @@ class EBSBackend(object):
self.snapshots[snapshot_id] = snapshot
return snapshot
def describe_snapshots(self, filters=None):
def describe_snapshots(self, snapshot_ids=None, filters=None):
matches = self.snapshots.values()
if snapshot_ids:
matches = [snap for snap in matches
if snap.id in snapshot_ids]
if len(snapshot_ids) > len(matches):
unknown_ids = set(snapshot_ids) - set(matches)
raise InvalidSnapshotIdError(unknown_ids)
if filters:
snapshots = self.snapshots.values()
return generic_filter(filters, snapshots)
return self.snapshots.values()
matches = generic_filter(filters, matches)
return matches
def get_snapshot(self, snapshot_id):
snapshot = self.snapshots.get(snapshot_id, None)
@ -1923,16 +1931,10 @@ class VPC(TaggedEC2Resource):
elif filter_name in ('dhcp-options-id', 'dhcpOptionsId'):
if not self.dhcp_options:
return None
return self.dhcp_options.id
filter_value = super(VPC, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeVPCs".format(filter_name))
return filter_value
else:
return super(VPC, self).get_filter_value(
filter_name, 'DescribeVpcs')
class VPCBackend(object):
@ -1965,12 +1967,16 @@ class VPCBackend(object):
return self.vpcs.get(vpc_id)
def get_all_vpcs(self, vpc_ids=None, filters=None):
matches = self.vpcs.values()
if vpc_ids:
vpcs = [vpc for vpc in self.vpcs.values() if vpc.id in vpc_ids]
else:
vpcs = self.vpcs.values()
return generic_filter(filters, vpcs)
matches = [vpc for vpc in matches
if vpc.id in vpc_ids]
if len(vpc_ids) > len(matches):
unknown_ids = set(vpc_ids) - set(matches)
raise InvalidVPCIdError(unknown_ids)
if filters:
matches = generic_filter(filters, matches)
return matches
def delete_vpc(self, vpc_id):
# Delete route table if only main route table remains.
@ -2166,14 +2172,9 @@ class Subnet(TaggedEC2Resource):
return self.availability_zone
elif filter_name in ('defaultForAz', 'default-for-az'):
return self.default_for_az
filter_value = super(Subnet, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeSubnets".format(filter_name))
return filter_value
else:
return super(Subnet, self).get_filter_value(
filter_name, 'DescribeSubnets')
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -2212,16 +2213,19 @@ class SubnetBackend(object):
return subnet
def get_all_subnets(self, subnet_ids=None, filters=None):
subnets = []
# Extract a list of all subnets
matches = itertools.chain(*[x.values()
for x in self.subnets.values()])
if subnet_ids:
for subnet_id in subnet_ids:
for items in self.subnets.values():
if subnet_id in items:
subnets.append(items[subnet_id])
else:
for items in self.subnets.values():
subnets.extend(items.values())
return generic_filter(filters, subnets)
matches = [sn for sn in matches
if sn.id in subnet_ids]
if len(subnet_ids) > len(matches):
unknown_ids = set(subnet_ids) - set(matches)
raise InvalidSubnetIdError(unknown_ids)
if filters:
matches = generic_filter(filters, matches)
return matches
def delete_subnet(self, subnet_id):
for subnets in self.subnets.values():
@ -2311,14 +2315,9 @@ class RouteTable(TaggedEC2Resource):
return self.associations.keys()
elif filter_name == "association.subnet-id":
return self.associations.values()
filter_value = super(RouteTable, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeRouteTables".format(filter_name))
return filter_value
else:
return super(RouteTable, self).get_filter_value(
filter_name, 'DescribeRouteTables')
class RouteTableBackend(object):
@ -2665,16 +2664,11 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
def get_filter_value(self, filter_name):
if filter_name == 'state':
return self.state
if filter_name == 'spot-instance-request-id':
elif filter_name == 'spot-instance-request-id':
return self.id
filter_value = super(SpotInstanceRequest,
self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeSpotInstanceRequests".format(filter_name))
return filter_value
else:
return super(SpotInstanceRequest, self).get_filter_value(
filter_name, 'DescribeSpotInstanceRequests')
def launch_instance(self):
reservation = self.ec2_backend.add_instances(
@ -2934,6 +2928,25 @@ class ElasticAddress(object):
return self.allocation_id
raise UnformattedGetAttTemplateException()
def get_filter_value(self, filter_name):
if filter_name == 'allocation-id':
return self.allocation_id
elif filter_name == 'association-id':
return self.association_id
elif filter_name == 'domain':
return self.domain
elif filter_name == 'instance-id' and self.instance:
return self.instance.id
elif filter_name == 'network-interface-id' and self.eni:
return self.eni.id
elif filter_name == 'private-ip-address' and self.eni:
return self.eni.private_ip_address
elif filter_name == 'public-ip':
return self.public_ip
else:
# TODO: implement network-interface-owner-id
raise FilterNotImplementedError(filter_name, 'DescribeAddresses')
class ElasticAddressBackend(object):
def __init__(self):
@ -2994,19 +3007,36 @@ class ElasticAddressBackend(object):
if new_instance_association or new_eni_association or reassociate:
eip.instance = instance
eip.eni = eni
if not eip.eni and instance:
# default to primary network interface
eip.eni = instance.nics[0]
if eip.eni:
eip.eni.public_ip = eip.public_ip
if eip.domain == "vpc":
eip.association_id = random_eip_association_id()
if instance:
instance.set_ip(eip.public_ip)
return eip
raise ResourceAlreadyAssociatedError(eip.public_ip)
def describe_addresses(self):
return self.addresses
def describe_addresses(self, allocation_ids=None, public_ips=None, filters=None):
matches = self.addresses
if allocation_ids:
matches = [addr for addr in matches
if addr.allocation_id in allocation_ids]
if len(allocation_ids) > len(matches):
unknown_ids = set(allocation_ids) - set(matches)
raise InvalidAllocationIdError(unknown_ids)
if public_ips:
matches = [addr for addr in matches
if addr.public_ip in public_ips]
if len(public_ips) > len(matches):
unknown_ips = set(allocation_ids) - set(matches)
raise InvalidAddressError(unknown_ips)
if filters:
matches = generic_filter(filters, matches)
return matches
def disassociate_address(self, address=None, association_id=None):
eips = []
@ -3017,10 +3047,9 @@ class ElasticAddressBackend(object):
eip = eips[0]
if eip.eni:
eip.eni.public_ip = None
if eip.eni.instance and eip.eni.instance._state.name == "running":
eip.eni.check_auto_public_ip()
else:
eip.eni.public_ip = None
eip.eni = None
eip.instance = None
@ -3076,15 +3105,9 @@ class DHCPOptionsSet(TaggedEC2Resource):
elif filter_name == 'value':
values = [item for item in list(self._options.values()) if item]
return itertools.chain(*values)
filter_value = super(
DHCPOptionsSet, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeDhcpOptions".format(filter_name))
return filter_value
else:
return super(DHCPOptionsSet, self).get_filter_value(
filter_name, 'DescribeDhcpOptions')
@property
def options(self):
@ -3171,6 +3194,10 @@ class VPNConnection(TaggedEC2Resource):
self.options = None
self.static_routes = None
def get_filter_value(self, filter_name):
return super(VPNConnection, self).get_filter_value(
filter_name, 'DescribeVpnConnections')
class VPNConnectionBackend(object):
def __init__(self):
@ -3350,14 +3377,9 @@ class NetworkAcl(TaggedEC2Resource):
return self.id
elif filter_name == "association.subnet-id":
return [assoc.subnet_id for assoc in self.associations.values()]
filter_value = super(NetworkAcl, self).get_filter_value(filter_name)
if filter_value is None:
self.ec2_backend.raise_not_implemented_error(
"The filter '{0}' for DescribeNetworkAcls".format(filter_name))
return filter_value
else:
return super(NetworkAcl, self).get_filter_value(
filter_name, 'DescribeNetworkAcls')
class NetworkAclEntry(TaggedEC2Resource):
@ -3386,6 +3408,10 @@ class VpnGateway(TaggedEC2Resource):
self.attachments = {}
super(VpnGateway, self).__init__()
def get_filter_value(self, filter_name):
return super(VpnGateway, self).get_filter_value(
filter_name, 'DescribeVpnGateways')
class VpnGatewayAttachment(object):
def __init__(self, vpc_id, state):
@ -3447,6 +3473,10 @@ class CustomerGateway(TaggedEC2Resource):
self.attachments = {}
super(CustomerGateway, self).__init__()
def get_filter_value(self, filter_name):
return super(CustomerGateway, self).get_filter_value(
filter_name, 'DescribeCustomerGateways')
class CustomerGatewayBackend(object):
def __init__(self):
@ -3590,10 +3620,7 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, AmiBackend,
raise EC2ClientError(code, message)
def raise_not_implemented_error(self, blurb):
msg = "{0} has not been implemented in Moto yet." \
" Feel free to open an issue at" \
" https://github.com/spulec/moto/issues".format(blurb)
raise NotImplementedError(msg)
raise MotoNotImplementedError(blurb)
def do_resources_exist(self, resource_ids):
for resource_id in resource_ids:
@ -3640,6 +3667,5 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, AmiBackend,
return True
ec2_backends = {}
for region in RegionsAndZonesBackend.regions:
ec2_backends[region.name] = EC2Backend(region.name)
ec2_backends = {region.name: EC2Backend(region.name)
for region in RegionsAndZonesBackend.regions}

File diff suppressed because one or more lines are too long

View File

@ -1,19 +1,14 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import instance_ids_from_querystring, image_ids_from_querystring, \
filters_from_querystring, sequence_from_querystring, executable_users_from_querystring
from moto.ec2.utils import filters_from_querystring
class AmisResponse(BaseResponse):
def create_image(self):
name = self.querystring.get('Name')[0]
if "Description" in self.querystring:
description = self.querystring.get('Description')[0]
else:
description = ""
instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0]
description = self._get_param('Description', if_none='')
instance_id = self._get_param('InstanceId')
if self.is_not_dryrun('CreateImage'):
image = self.ec2_backend.create_image(
instance_id, name, description)
@ -21,12 +16,10 @@ class AmisResponse(BaseResponse):
return template.render(image=image)
def copy_image(self):
source_image_id = self.querystring.get('SourceImageId')[0]
source_region = self.querystring.get('SourceRegion')[0]
name = self.querystring.get(
'Name')[0] if self.querystring.get('Name') else None
description = self.querystring.get(
'Description')[0] if self.querystring.get('Description') else None
source_image_id = self._get_param('SourceImageId')
source_region = self._get_param('SourceRegion')
name = self._get_param('Name')
description = self._get_param('Description')
if self.is_not_dryrun('CopyImage'):
image = self.ec2_backend.copy_image(
source_image_id, source_region, name, description)
@ -34,33 +27,33 @@ class AmisResponse(BaseResponse):
return template.render(image=image)
def deregister_image(self):
ami_id = self.querystring.get('ImageId')[0]
ami_id = self._get_param('ImageId')
if self.is_not_dryrun('DeregisterImage'):
success = self.ec2_backend.deregister_image(ami_id)
template = self.response_template(DEREGISTER_IMAGE_RESPONSE)
return template.render(success=str(success).lower())
def describe_images(self):
ami_ids = image_ids_from_querystring(self.querystring)
ami_ids = self._get_multi_param('ImageId')
filters = filters_from_querystring(self.querystring)
exec_users = executable_users_from_querystring(self.querystring)
exec_users = self._get_multi_param('ExecutableBy')
images = self.ec2_backend.describe_images(
ami_ids=ami_ids, filters=filters, exec_users=exec_users)
template = self.response_template(DESCRIBE_IMAGES_RESPONSE)
return template.render(images=images)
def describe_image_attribute(self):
ami_id = self.querystring.get('ImageId')[0]
ami_id = self._get_param('ImageId')
groups = self.ec2_backend.get_launch_permission_groups(ami_id)
users = self.ec2_backend.get_launch_permission_users(ami_id)
template = self.response_template(DESCRIBE_IMAGE_ATTRIBUTES_RESPONSE)
return template.render(ami_id=ami_id, groups=groups, users=users)
def modify_image_attribute(self):
ami_id = self.querystring.get('ImageId')[0]
operation_type = self.querystring.get('OperationType')[0]
group = self.querystring.get('UserGroup.1', [None])[0]
user_ids = sequence_from_querystring('UserId', self.querystring)
ami_id = self._get_param('ImageId')
operation_type = self._get_param('OperationType')
group = self._get_param('UserGroup.1')
user_ids = self._get_multi_param('UserId')
if self.is_not_dryrun('ModifyImageAttribute'):
if (operation_type == 'add'):
self.ec2_backend.add_launch_permission(
@ -115,7 +108,7 @@ DESCRIBE_IMAGES_RESPONSE = """<DescribeImagesResponse xmlns="http://ec2.amazonaw
{% endif %}
<description>{{ image.description }}</description>
<rootDeviceType>ebs</rootDeviceType>
<rootDeviceName>/dev/sda</rootDeviceName>
<rootDeviceName>/dev/sda1</rootDeviceName>
<blockDeviceMapping>
<item>
<deviceName>/dev/sda1</deviceName>

View File

@ -7,16 +7,16 @@ class CustomerGateways(BaseResponse):
def create_customer_gateway(self):
# raise NotImplementedError('CustomerGateways(AmazonVPC).create_customer_gateway is not yet implemented')
type = self.querystring.get('Type', None)[0]
ip_address = self.querystring.get('IpAddress', None)[0]
bgp_asn = self.querystring.get('BgpAsn', None)[0]
type = self._get_param('Type')
ip_address = self._get_param('IpAddress')
bgp_asn = self._get_param('BgpAsn')
customer_gateway = self.ec2_backend.create_customer_gateway(
type, ip_address=ip_address, bgp_asn=bgp_asn)
template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=customer_gateway)
def delete_customer_gateway(self):
customer_gateway_id = self.querystring.get('CustomerGatewayId')[0]
customer_gateway_id = self._get_param('CustomerGatewayId')
delete_status = self.ec2_backend.delete_customer_gateway(
customer_gateway_id)
template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE)

View File

@ -2,15 +2,14 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import (
filters_from_querystring,
sequence_from_querystring,
dhcp_configuration_from_querystring)
class DHCPOptions(BaseResponse):
def associate_dhcp_options(self):
dhcp_opt_id = self.querystring.get("DhcpOptionsId", [None])[0]
vpc_id = self.querystring.get("VpcId", [None])[0]
dhcp_opt_id = self._get_param('DhcpOptionsId')
vpc_id = self._get_param('VpcId')
dhcp_opt = self.ec2_backend.describe_dhcp_options([dhcp_opt_id])[0]
vpc = self.ec2_backend.get_vpc(vpc_id)
@ -43,14 +42,13 @@ class DHCPOptions(BaseResponse):
return template.render(dhcp_options_set=dhcp_options_set)
def delete_dhcp_options(self):
dhcp_opt_id = self.querystring.get("DhcpOptionsId", [None])[0]
dhcp_opt_id = self._get_param('DhcpOptionsId')
delete_status = self.ec2_backend.delete_dhcp_options_set(dhcp_opt_id)
template = self.response_template(DELETE_DHCP_OPTIONS_RESPONSE)
return template.render(delete_status=delete_status)
def describe_dhcp_options(self):
dhcp_opt_ids = sequence_from_querystring(
"DhcpOptionsId", self.querystring)
dhcp_opt_ids = self._get_multi_param("DhcpOptionsId")
filters = filters_from_querystring(self.querystring)
dhcp_opts = self.ec2_backend.get_all_dhcp_options(
dhcp_opt_ids, filters)

View File

@ -6,9 +6,9 @@ from moto.ec2.utils import filters_from_querystring
class ElasticBlockStore(BaseResponse):
def attach_volume(self):
volume_id = self.querystring.get('VolumeId')[0]
instance_id = self.querystring.get('InstanceId')[0]
device_path = self.querystring.get('Device')[0]
volume_id = self._get_param('VolumeId')
instance_id = self._get_param('InstanceId')
device_path = self._get_param('Device')
if self.is_not_dryrun('AttachVolume'):
attachment = self.ec2_backend.attach_volume(
volume_id, instance_id, device_path)
@ -21,18 +21,18 @@ class ElasticBlockStore(BaseResponse):
'ElasticBlockStore.copy_snapshot is not yet implemented')
def create_snapshot(self):
description = self.querystring.get('Description', [None])[0]
volume_id = self.querystring.get('VolumeId')[0]
volume_id = self._get_param('VolumeId')
description = self._get_param('Description')
if self.is_not_dryrun('CreateSnapshot'):
snapshot = self.ec2_backend.create_snapshot(volume_id, description)
template = self.response_template(CREATE_SNAPSHOT_RESPONSE)
return template.render(snapshot=snapshot)
def create_volume(self):
size = self.querystring.get('Size', [None])[0]
zone = self.querystring.get('AvailabilityZone', [None])[0]
snapshot_id = self.querystring.get('SnapshotId', [None])[0]
encrypted = self.querystring.get('Encrypted', ['false'])[0]
size = self._get_param('Size')
zone = self._get_param('AvailabilityZone')
snapshot_id = self._get_param('SnapshotId')
encrypted = self._get_param('Encrypted', if_none=False)
if self.is_not_dryrun('CreateVolume'):
volume = self.ec2_backend.create_volume(
size, zone, snapshot_id, encrypted)
@ -40,40 +40,28 @@ class ElasticBlockStore(BaseResponse):
return template.render(volume=volume)
def delete_snapshot(self):
snapshot_id = self.querystring.get('SnapshotId')[0]
snapshot_id = self._get_param('SnapshotId')
if self.is_not_dryrun('DeleteSnapshot'):
self.ec2_backend.delete_snapshot(snapshot_id)
return DELETE_SNAPSHOT_RESPONSE
def delete_volume(self):
volume_id = self.querystring.get('VolumeId')[0]
volume_id = self._get_param('VolumeId')
if self.is_not_dryrun('DeleteVolume'):
self.ec2_backend.delete_volume(volume_id)
return DELETE_VOLUME_RESPONSE
def describe_snapshots(self):
filters = filters_from_querystring(self.querystring)
# querystring for multiple snapshotids results in SnapshotId.1,
# SnapshotId.2 etc
snapshot_ids = ','.join(
[','.join(s[1]) for s in self.querystring.items() if 'SnapshotId' in s[0]])
snapshots = self.ec2_backend.describe_snapshots(filters=filters)
# Describe snapshots to handle filter on snapshot_ids
snapshots = [
s for s in snapshots if s.id in snapshot_ids] if snapshot_ids else snapshots
snapshot_ids = self._get_multi_param('SnapshotId')
snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters)
template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE)
return template.render(snapshots=snapshots)
def describe_volumes(self):
filters = filters_from_querystring(self.querystring)
# querystring for multiple volumeids results in VolumeId.1, VolumeId.2
# etc
volume_ids = ','.join(
[','.join(v[1]) for v in self.querystring.items() if 'VolumeId' in v[0]])
volumes = self.ec2_backend.describe_volumes(filters=filters)
# Describe volumes to handle filter on volume_ids
volumes = [
v for v in volumes if v.id in volume_ids] if volume_ids else volumes
volume_ids = self._get_multi_param('VolumeId')
volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters)
template = self.response_template(DESCRIBE_VOLUMES_RESPONSE)
return template.render(volumes=volumes)
@ -86,9 +74,9 @@ class ElasticBlockStore(BaseResponse):
'ElasticBlockStore.describe_volume_status is not yet implemented')
def detach_volume(self):
volume_id = self.querystring.get('VolumeId')[0]
instance_id = self.querystring.get('InstanceId')[0]
device_path = self.querystring.get('Device')[0]
volume_id = self._get_param('VolumeId')
instance_id = self._get_param('InstanceId')
device_path = self._get_param('Device')
if self.is_not_dryrun('DetachVolume'):
attachment = self.ec2_backend.detach_volume(
volume_id, instance_id, device_path)
@ -106,7 +94,7 @@ class ElasticBlockStore(BaseResponse):
'ElasticBlockStore.import_volume is not yet implemented')
def describe_snapshot_attribute(self):
snapshot_id = self.querystring.get('SnapshotId')[0]
snapshot_id = self._get_param('SnapshotId')
groups = self.ec2_backend.get_create_volume_permission_groups(
snapshot_id)
template = self.response_template(
@ -114,10 +102,10 @@ class ElasticBlockStore(BaseResponse):
return template.render(snapshot_id=snapshot_id, groups=groups)
def modify_snapshot_attribute(self):
snapshot_id = self.querystring.get('SnapshotId')[0]
operation_type = self.querystring.get('OperationType')[0]
group = self.querystring.get('UserGroup.1', [None])[0]
user_id = self.querystring.get('UserId.1', [None])[0]
snapshot_id = self._get_param('SnapshotId')
operation_type = self._get_param('OperationType')
group = self._get_param('UserGroup.1')
user_id = self._get_param('UserId.1')
if self.is_not_dryrun('ModifySnapshotAttribute'):
if (operation_type == 'add'):
self.ec2_backend.add_create_volume_permission(

View File

@ -1,15 +1,12 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import sequence_from_querystring
from moto.ec2.utils import filters_from_querystring
class ElasticIPAddresses(BaseResponse):
def allocate_address(self):
if "Domain" in self.querystring:
domain = self.querystring.get('Domain')[0]
else:
domain = "standard"
domain = self._get_param('Domain', if_none='standard')
if self.is_not_dryrun('AllocateAddress'):
address = self.ec2_backend.allocate_address(domain)
template = self.response_template(ALLOCATE_ADDRESS_RESPONSE)
@ -20,26 +17,28 @@ class ElasticIPAddresses(BaseResponse):
if "InstanceId" in self.querystring:
instance = self.ec2_backend.get_instance(
self.querystring['InstanceId'][0])
self._get_param('InstanceId'))
elif "NetworkInterfaceId" in self.querystring:
eni = self.ec2_backend.get_network_interface(
self.querystring['NetworkInterfaceId'][0])
self._get_param('NetworkInterfaceId'))
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.")
reassociate = False
if "AllowReassociation" in self.querystring:
reassociate = self.querystring['AllowReassociation'][0] == "true"
reassociate = self._get_param('AllowReassociation') == "true"
if self.is_not_dryrun('AssociateAddress'):
if instance or eni:
if "PublicIp" in self.querystring:
eip = self.ec2_backend.associate_address(instance=instance, eni=eni, address=self.querystring[
'PublicIp'][0], reassociate=reassociate)
eip = self.ec2_backend.associate_address(
instance=instance, eni=eni,
address=self._get_param('PublicIp'), reassociate=reassociate)
elif "AllocationId" in self.querystring:
eip = self.ec2_backend.associate_address(instance=instance, eni=eni, allocation_id=self.querystring[
'AllocationId'][0], reassociate=reassociate)
eip = self.ec2_backend.associate_address(
instance=instance, eni=eni,
allocation_id=self._get_param('AllocationId'), reassociate=reassociate)
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.")
@ -51,39 +50,22 @@ class ElasticIPAddresses(BaseResponse):
return template.render(address=eip)
def describe_addresses(self):
allocation_ids = self._get_multi_param('AllocationId')
public_ips = self._get_multi_param('PublicIp')
filters = filters_from_querystring(self.querystring)
addresses = self.ec2_backend.describe_addresses(
allocation_ids, public_ips, filters)
template = self.response_template(DESCRIBE_ADDRESS_RESPONSE)
if "Filter.1.Name" in self.querystring:
filter_by = sequence_from_querystring(
"Filter.1.Name", self.querystring)[0]
filter_value = sequence_from_querystring(
"Filter.1.Value", self.querystring)
if filter_by == 'instance-id':
addresses = filter(lambda x: x.instance.id == filter_value[
0], self.ec2_backend.describe_addresses())
else:
raise NotImplementedError(
"Filtering not supported in describe_address.")
elif "PublicIp.1" in self.querystring:
public_ips = sequence_from_querystring(
"PublicIp", self.querystring)
addresses = self.ec2_backend.address_by_ip(public_ips)
elif "AllocationId.1" in self.querystring:
allocation_ids = sequence_from_querystring(
"AllocationId", self.querystring)
addresses = self.ec2_backend.address_by_allocation(allocation_ids)
else:
addresses = self.ec2_backend.describe_addresses()
return template.render(addresses=addresses)
def disassociate_address(self):
if self.is_not_dryrun('DisAssociateAddress'):
if "PublicIp" in self.querystring:
self.ec2_backend.disassociate_address(
address=self.querystring['PublicIp'][0])
address=self._get_param('PublicIp'))
elif "AssociationId" in self.querystring:
self.ec2_backend.disassociate_address(
association_id=self.querystring['AssociationId'][0])
association_id=self._get_param('AssociationId'))
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.")
@ -94,10 +76,10 @@ class ElasticIPAddresses(BaseResponse):
if self.is_not_dryrun('ReleaseAddress'):
if "PublicIp" in self.querystring:
self.ec2_backend.release_address(
address=self.querystring['PublicIp'][0])
address=self._get_param('PublicIp'))
elif "AllocationId" in self.querystring:
self.ec2_backend.release_address(
allocation_id=self.querystring['AllocationId'][0])
allocation_id=self._get_param('AllocationId'))
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.")

View File

@ -1,15 +1,14 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import sequence_from_querystring, filters_from_querystring
from moto.ec2.utils import filters_from_querystring
class ElasticNetworkInterfaces(BaseResponse):
def create_network_interface(self):
subnet_id = self.querystring.get('SubnetId')[0]
private_ip_address = self.querystring.get(
'PrivateIpAddress', [None])[0]
groups = sequence_from_querystring('SecurityGroupId', self.querystring)
subnet_id = self._get_param('SubnetId')
private_ip_address = self._get_param('PrivateIpAddress')
groups = self._get_multi_param('SecurityGroupId')
subnet = self.ec2_backend.get_subnet(subnet_id)
if self.is_not_dryrun('CreateNetworkInterface'):
eni = self.ec2_backend.create_network_interface(
@ -19,7 +18,7 @@ class ElasticNetworkInterfaces(BaseResponse):
return template.render(eni=eni)
def delete_network_interface(self):
eni_id = self.querystring.get('NetworkInterfaceId')[0]
eni_id = self._get_param('NetworkInterfaceId')
if self.is_not_dryrun('DeleteNetworkInterface'):
self.ec2_backend.delete_network_interface(eni_id)
template = self.response_template(
@ -31,17 +30,16 @@ class ElasticNetworkInterfaces(BaseResponse):
'ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented')
def describe_network_interfaces(self):
eni_ids = sequence_from_querystring(
'NetworkInterfaceId', self.querystring)
eni_ids = self._get_multi_param('NetworkInterfaceId')
filters = filters_from_querystring(self.querystring)
enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters)
template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE)
return template.render(enis=enis)
def attach_network_interface(self):
eni_id = self.querystring.get('NetworkInterfaceId')[0]
instance_id = self.querystring.get('InstanceId')[0]
device_index = self.querystring.get('DeviceIndex')[0]
eni_id = self._get_param('NetworkInterfaceId')
instance_id = self._get_param('InstanceId')
device_index = self._get_param('DeviceIndex')
if self.is_not_dryrun('AttachNetworkInterface'):
attachment_id = self.ec2_backend.attach_network_interface(
eni_id, instance_id, device_index)
@ -50,7 +48,7 @@ class ElasticNetworkInterfaces(BaseResponse):
return template.render(attachment_id=attachment_id)
def detach_network_interface(self):
attachment_id = self.querystring.get('AttachmentId')[0]
attachment_id = self._get_param('AttachmentId')
if self.is_not_dryrun('DetachNetworkInterface'):
self.ec2_backend.detach_network_interface(attachment_id)
template = self.response_template(
@ -59,8 +57,8 @@ class ElasticNetworkInterfaces(BaseResponse):
def modify_network_interface_attribute(self):
# Currently supports modifying one and only one security group
eni_id = self.querystring.get('NetworkInterfaceId')[0]
group_id = self.querystring.get('SecurityGroupId.1')[0]
eni_id = self._get_param('NetworkInterfaceId')
group_id = self._get_param('SecurityGroupId.1')
if self.is_not_dryrun('ModifyNetworkInterface'):
self.ec2_backend.modify_network_interface_attribute(
eni_id, group_id)

View File

@ -1,13 +1,16 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import instance_ids_from_querystring
class General(BaseResponse):
def get_console_output(self):
self.instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = self.instance_ids[0]
instance_id = self._get_param('InstanceId')
if not instance_id:
# For compatibility with boto.
# See: https://github.com/spulec/moto/pull/1152#issuecomment-332487599
instance_id = self._get_multi_param('InstanceId')[0]
instance = self.ec2_backend.get_instance(instance_id)
template = self.response_template(GET_CONSOLE_OUTPUT_RESULT)
return template.render(instance=instance)

View File

@ -2,15 +2,15 @@ from __future__ import unicode_literals
from boto.ec2.instancetype import InstanceType
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import instance_ids_from_querystring, filters_from_querystring, \
dict_from_querystring, optional_from_querystring
from moto.ec2.utils import filters_from_querystring, \
dict_from_querystring
class InstanceResponse(BaseResponse):
def describe_instances(self):
filter_dict = filters_from_querystring(self.querystring)
instance_ids = instance_ids_from_querystring(self.querystring)
instance_ids = self._get_multi_param('InstanceId')
token = self._get_param("NextToken")
if instance_ids:
reservations = self.ec2_backend.get_reservations_by_instance_ids(
@ -33,26 +33,25 @@ class InstanceResponse(BaseResponse):
return template.render(reservations=reservations_resp, next_token=next_token)
def run_instances(self):
min_count = int(self.querystring.get('MinCount', ['1'])[0])
image_id = self.querystring.get('ImageId')[0]
user_data = self.querystring.get('UserData')
min_count = int(self._get_param('MinCount', if_none='1'))
image_id = self._get_param('ImageId')
user_data = self._get_param('UserData')
security_group_names = self._get_multi_param('SecurityGroup')
security_group_ids = self._get_multi_param('SecurityGroupId')
nics = dict_from_querystring("NetworkInterface", self.querystring)
instance_type = self.querystring.get("InstanceType", ["m1.small"])[0]
placement = self.querystring.get(
"Placement.AvailabilityZone", [None])[0]
subnet_id = self.querystring.get("SubnetId", [None])[0]
private_ip = self.querystring.get("PrivateIpAddress", [None])[0]
associate_public_ip = self.querystring.get(
"AssociatePublicIpAddress", [None])[0]
key_name = self.querystring.get("KeyName", [None])[0]
instance_type = self._get_param('InstanceType', if_none='m1.small')
placement = self._get_param('Placement.AvailabilityZone')
subnet_id = self._get_param('SubnetId')
private_ip = self._get_param('PrivateIpAddress')
associate_public_ip = self._get_param('AssociatePublicIpAddress')
key_name = self._get_param('KeyName')
tags = self._parse_tag_specification("TagSpecification")
region_name = self.region
if self.is_not_dryrun('RunInstance'):
new_reservation = self.ec2_backend.add_instances(
image_id, min_count, user_data, security_group_names,
instance_type=instance_type, placement=placement, subnet_id=subnet_id,
instance_type=instance_type, placement=placement, region_name=region_name, subnet_id=subnet_id,
key_name=key_name, security_group_ids=security_group_ids,
nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip,
tags=tags)
@ -61,37 +60,36 @@ class InstanceResponse(BaseResponse):
return template.render(reservation=new_reservation)
def terminate_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring)
instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('TerminateInstance'):
instances = self.ec2_backend.terminate_instances(instance_ids)
template = self.response_template(EC2_TERMINATE_INSTANCES)
return template.render(instances=instances)
def reboot_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring)
instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('RebootInstance'):
instances = self.ec2_backend.reboot_instances(instance_ids)
template = self.response_template(EC2_REBOOT_INSTANCES)
return template.render(instances=instances)
def stop_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring)
instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('StopInstance'):
instances = self.ec2_backend.stop_instances(instance_ids)
template = self.response_template(EC2_STOP_INSTANCES)
return template.render(instances=instances)
def start_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring)
instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('StartInstance'):
instances = self.ec2_backend.start_instances(instance_ids)
template = self.response_template(EC2_START_INSTANCES)
return template.render(instances=instances)
def describe_instance_status(self):
instance_ids = instance_ids_from_querystring(self.querystring)
include_all_instances = optional_from_querystring('IncludeAllInstances',
self.querystring) == 'true'
instance_ids = self._get_multi_param('InstanceId')
include_all_instances = self._get_param('IncludeAllInstances') == 'true'
if instance_ids:
instances = self.ec2_backend.get_multi_instances_by_id(
@ -113,10 +111,9 @@ class InstanceResponse(BaseResponse):
def describe_instance_attribute(self):
# TODO this and modify below should raise IncorrectInstanceState if
# instance not in stopped state
attribute = self.querystring.get("Attribute")[0]
attribute = self._get_param('Attribute')
key = camelcase_to_underscores(attribute)
instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0]
instance_id = self._get_param('InstanceId')
instance, value = self.ec2_backend.describe_instance_attribute(
instance_id, key)
@ -170,8 +167,7 @@ class InstanceResponse(BaseResponse):
del_on_term_value = True if 'true' == del_on_term_value_str else False
device_name_value = self.querystring[mapping_device_name][0]
instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0]
instance_id = self._get_param('InstanceId')
instance = self.ec2_backend.get_instance(instance_id)
if self.is_not_dryrun('ModifyInstanceAttribute'):
@ -199,8 +195,7 @@ class InstanceResponse(BaseResponse):
value = self.querystring.get(attribute_key)[0]
normalized_attribute = camelcase_to_underscores(
attribute_key.split(".")[0])
instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0]
instance_id = self._get_param('InstanceId')
self.ec2_backend.modify_instance_attribute(
instance_id, normalized_attribute, value)
return EC2_MODIFY_INSTANCE_ATTRIBUTE
@ -211,8 +206,7 @@ class InstanceResponse(BaseResponse):
if 'GroupId.' in key:
new_security_grp_list.append(self.querystring.get(key)[0])
instance_ids = instance_ids_from_querystring(self.querystring)
instance_id = instance_ids[0]
instance_id = self._get_param('InstanceId')
if self.is_not_dryrun('ModifyInstanceSecurityGroups'):
self.ec2_backend.modify_instance_security_groups(
instance_id, new_security_grp_list)
@ -254,17 +248,19 @@ EC2_RUN_INSTANCES = """<RunInstancesResponse xmlns="http://ec2.amazonaws.com/doc
<monitoring>
<state>enabled</state>
</monitoring>
{% if instance.nics %}
{% if instance.nics[0].subnet %}
<subnetId>{{ instance.nics[0].subnet.id }}</subnetId>
<vpcId>{{ instance.nics[0].subnet.vpc_id }}</vpcId>
{% endif %}
<privateIpAddress>{{ instance.private_ip }}</privateIpAddress>
{% if instance.public_ip %}
<ipAddress>{{ instance.public_ip }}</ipAddress>
{% endif %}
{% else %}
{% if instance.subnet_id %}
<subnetId>{{ instance.subnet_id }}</subnetId>
{% elif instance.nics[0].subnet.id %}
<subnetId>{{ instance.nics[0].subnet.id }}</subnetId>
{% endif %}
{% if instance.vpc_id %}
<vpcId>{{ instance.vpc_id }}</vpcId>
{% elif instance.nics[0].subnet.vpc_id %}
<vpcId>{{ instance.nics[0].subnet.vpc_id }}</vpcId>
{% endif %}
<privateIpAddress>{{ instance.private_ip }}</privateIpAddress>
{% if instance.nics[0].public_ip %}
<ipAddress>{{ instance.nics[0].public_ip }}</ipAddress>
{% endif %}
<sourceDestCheck>{{ instance.source_dest_check }}</sourceDestCheck>
<groupSet>
@ -395,26 +391,30 @@ EC2_DESCRIBE_INSTANCES = """<DescribeInstancesResponse xmlns="http://ec2.amazona
<monitoring>
<state>disabled</state>
</monitoring>
{% if instance.nics %}
{% if instance.nics[0].subnet %}
<subnetId>{{ instance.nics[0].subnet.id }}</subnetId>
<vpcId>{{ instance.nics[0].subnet.vpc_id }}</vpcId>
{% endif %}
<privateIpAddress>{{ instance.private_ip }}</privateIpAddress>
{% if instance.nics[0].public_ip %}
<ipAddress>{{ instance.nics[0].public_ip }}</ipAddress>
{% endif %}
{% if instance.subnet_id %}
<subnetId>{{ instance.subnet_id }}</subnetId>
{% elif instance.nics[0].subnet.id %}
<subnetId>{{ instance.nics[0].subnet.id }}</subnetId>
{% endif %}
{% if instance.vpc_id %}
<vpcId>{{ instance.vpc_id }}</vpcId>
{% elif instance.nics[0].subnet.vpc_id %}
<vpcId>{{ instance.nics[0].subnet.vpc_id }}</vpcId>
{% endif %}
<privateIpAddress>{{ instance.private_ip }}</privateIpAddress>
{% if instance.nics[0].public_ip %}
<ipAddress>{{ instance.nics[0].public_ip }}</ipAddress>
{% endif %}
<sourceDestCheck>{{ instance.source_dest_check }}</sourceDestCheck>
<groupSet>
{% for group in instance.dynamic_group_list %}
<item>
{% if group.id %}
<groupId>{{ group.id }}</groupId>
<groupName>{{ group.name }}</groupName>
{% else %}
<groupId>{{ group }}</groupId>
{% endif %}
{% if group.id %}
<groupId>{{ group.id }}</groupId>
<groupName>{{ group.name }}</groupName>
{% else %}
<groupId>{{ group }}</groupId>
{% endif %}
</item>
{% endfor %}
</groupSet>

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import (
sequence_from_querystring,
filters_from_querystring,
)
@ -9,8 +8,8 @@ from moto.ec2.utils import (
class InternetGateways(BaseResponse):
def attach_internet_gateway(self):
igw_id = self.querystring.get("InternetGatewayId", [None])[0]
vpc_id = self.querystring.get("VpcId", [None])[0]
igw_id = self._get_param('InternetGatewayId')
vpc_id = self._get_param('VpcId')
if self.is_not_dryrun('AttachInternetGateway'):
self.ec2_backend.attach_internet_gateway(igw_id, vpc_id)
template = self.response_template(ATTACH_INTERNET_GATEWAY_RESPONSE)
@ -23,7 +22,7 @@ class InternetGateways(BaseResponse):
return template.render(internet_gateway=igw)
def delete_internet_gateway(self):
igw_id = self.querystring.get("InternetGatewayId", [None])[0]
igw_id = self._get_param('InternetGatewayId')
if self.is_not_dryrun('DeleteInternetGateway'):
self.ec2_backend.delete_internet_gateway(igw_id)
template = self.response_template(DELETE_INTERNET_GATEWAY_RESPONSE)
@ -32,8 +31,7 @@ class InternetGateways(BaseResponse):
def describe_internet_gateways(self):
filter_dict = filters_from_querystring(self.querystring)
if "InternetGatewayId.1" in self.querystring:
igw_ids = sequence_from_querystring(
"InternetGatewayId", self.querystring)
igw_ids = self._get_multi_param("InternetGatewayId")
igws = self.ec2_backend.describe_internet_gateways(
igw_ids, filters=filter_dict)
else:
@ -46,8 +44,8 @@ class InternetGateways(BaseResponse):
def detach_internet_gateway(self):
# TODO validate no instances with EIPs in VPC before detaching
# raise else DependencyViolationError()
igw_id = self.querystring.get("InternetGatewayId", [None])[0]
vpc_id = self.querystring.get("VpcId", [None])[0]
igw_id = self._get_param('InternetGatewayId')
vpc_id = self._get_param('VpcId')
if self.is_not_dryrun('DetachInternetGateway'):
self.ec2_backend.detach_internet_gateway(igw_id, vpc_id)
template = self.response_template(DETACH_INTERNET_GATEWAY_RESPONSE)

View File

@ -1,43 +1,39 @@
from __future__ import unicode_literals
import six
from moto.core.responses import BaseResponse
from moto.ec2.utils import keypair_names_from_querystring, filters_from_querystring
from moto.ec2.utils import filters_from_querystring
class KeyPairs(BaseResponse):
def create_key_pair(self):
name = self.querystring.get('KeyName')[0]
name = self._get_param('KeyName')
if self.is_not_dryrun('CreateKeyPair'):
keypair = self.ec2_backend.create_key_pair(name)
template = self.response_template(CREATE_KEY_PAIR_RESPONSE)
return template.render(**keypair)
return template.render(keypair=keypair)
def delete_key_pair(self):
name = self.querystring.get('KeyName')[0]
name = self._get_param('KeyName')
if self.is_not_dryrun('DeleteKeyPair'):
success = six.text_type(
self.ec2_backend.delete_key_pair(name)).lower()
return self.response_template(DELETE_KEY_PAIR_RESPONSE).render(success=success)
def describe_key_pairs(self):
names = keypair_names_from_querystring(self.querystring)
names = self._get_multi_param('KeyName')
filters = filters_from_querystring(self.querystring)
if len(filters) > 0:
raise NotImplementedError(
'Using filters in KeyPairs.describe_key_pairs is not yet implemented')
keypairs = self.ec2_backend.describe_key_pairs(names)
keypairs = self.ec2_backend.describe_key_pairs(names, filters)
template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE)
return template.render(keypairs=keypairs)
def import_key_pair(self):
name = self.querystring.get('KeyName')[0]
material = self.querystring.get('PublicKeyMaterial')[0]
name = self._get_param('KeyName')
material = self._get_param('PublicKeyMaterial')
if self.is_not_dryrun('ImportKeyPair'):
keypair = self.ec2_backend.import_key_pair(name, material)
template = self.response_template(IMPORT_KEYPAIR_RESPONSE)
return template.render(**keypair)
return template.render(keypair=keypair)
DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@ -54,12 +50,9 @@ DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.ama
CREATE_KEY_PAIR_RESPONSE = """<CreateKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<keyName>{{ name }}</keyName>
<keyFingerprint>
{{ fingerprint }}
</keyFingerprint>
<keyMaterial>{{ material }}
</keyMaterial>
<keyName>{{ keypair.name }}</keyName>
<keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
<keyMaterial>{{ keypair.material }}</keyMaterial>
</CreateKeyPairResponse>"""
@ -71,6 +64,6 @@ DELETE_KEY_PAIR_RESPONSE = """<DeleteKeyPairResponse xmlns="http://ec2.amazonaws
IMPORT_KEYPAIR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ImportKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>471f9fdd-8fe2-4a84-86b0-bd3d3e350979</requestId>
<keyName>{{ name }}</keyName>
<keyFingerprint>{{ fingerprint }}</keyFingerprint>
<keyName>{{ keypair.name }}</keyName>
<keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
</ImportKeyPairResponse>"""

View File

@ -1,28 +1,27 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring, \
network_acl_ids_from_querystring
from moto.ec2.utils import filters_from_querystring
class NetworkACLs(BaseResponse):
def create_network_acl(self):
vpc_id = self.querystring.get('VpcId')[0]
vpc_id = self._get_param('VpcId')
network_acl = self.ec2_backend.create_network_acl(vpc_id)
template = self.response_template(CREATE_NETWORK_ACL_RESPONSE)
return template.render(network_acl=network_acl)
def create_network_acl_entry(self):
network_acl_id = self.querystring.get('NetworkAclId')[0]
rule_number = self.querystring.get('RuleNumber')[0]
protocol = self.querystring.get('Protocol')[0]
rule_action = self.querystring.get('RuleAction')[0]
egress = self.querystring.get('Egress')[0]
cidr_block = self.querystring.get('CidrBlock')[0]
icmp_code = self.querystring.get('Icmp.Code', [None])[0]
icmp_type = self.querystring.get('Icmp.Type', [None])[0]
port_range_from = self.querystring.get('PortRange.From')[0]
port_range_to = self.querystring.get('PortRange.To')[0]
network_acl_id = self._get_param('NetworkAclId')
rule_number = self._get_param('RuleNumber')
protocol = self._get_param('Protocol')
rule_action = self._get_param('RuleAction')
egress = self._get_param('Egress')
cidr_block = self._get_param('CidrBlock')
icmp_code = self._get_param('Icmp.Code')
icmp_type = self._get_param('Icmp.Type')
port_range_from = self._get_param('PortRange.From')
port_range_to = self._get_param('PortRange.To')
network_acl_entry = self.ec2_backend.create_network_acl_entry(
network_acl_id, rule_number, protocol, rule_action,
@ -33,30 +32,30 @@ class NetworkACLs(BaseResponse):
return template.render(network_acl_entry=network_acl_entry)
def delete_network_acl(self):
network_acl_id = self.querystring.get('NetworkAclId')[0]
network_acl_id = self._get_param('NetworkAclId')
self.ec2_backend.delete_network_acl(network_acl_id)
template = self.response_template(DELETE_NETWORK_ACL_ASSOCIATION)
return template.render()
def delete_network_acl_entry(self):
network_acl_id = self.querystring.get('NetworkAclId')[0]
rule_number = self.querystring.get('RuleNumber')[0]
egress = self.querystring.get('Egress')[0]
network_acl_id = self._get_param('NetworkAclId')
rule_number = self._get_param('RuleNumber')
egress = self._get_param('Egress')
self.ec2_backend.delete_network_acl_entry(network_acl_id, rule_number, egress)
template = self.response_template(DELETE_NETWORK_ACL_ENTRY_RESPONSE)
return template.render()
def replace_network_acl_entry(self):
network_acl_id = self.querystring.get('NetworkAclId')[0]
rule_number = self.querystring.get('RuleNumber')[0]
protocol = self.querystring.get('Protocol')[0]
rule_action = self.querystring.get('RuleAction')[0]
egress = self.querystring.get('Egress')[0]
cidr_block = self.querystring.get('CidrBlock')[0]
icmp_code = self.querystring.get('Icmp.Code', [None])[0]
icmp_type = self.querystring.get('Icmp.Type', [None])[0]
port_range_from = self.querystring.get('PortRange.From')[0]
port_range_to = self.querystring.get('PortRange.To')[0]
network_acl_id = self._get_param('NetworkAclId')
rule_number = self._get_param('RuleNumber')
protocol = self._get_param('Protocol')
rule_action = self._get_param('RuleAction')
egress = self._get_param('Egress')
cidr_block = self._get_param('CidrBlock')
icmp_code = self._get_param('Icmp.Code')
icmp_type = self._get_param('Icmp.Type')
port_range_from = self._get_param('PortRange.From')
port_range_to = self._get_param('PortRange.To')
self.ec2_backend.replace_network_acl_entry(
network_acl_id, rule_number, protocol, rule_action,
@ -67,7 +66,7 @@ class NetworkACLs(BaseResponse):
return template.render()
def describe_network_acls(self):
network_acl_ids = network_acl_ids_from_querystring(self.querystring)
network_acl_ids = self._get_multi_param('NetworkAclId')
filters = filters_from_querystring(self.querystring)
network_acls = self.ec2_backend.get_all_network_acls(
network_acl_ids, filters)
@ -75,8 +74,8 @@ class NetworkACLs(BaseResponse):
return template.render(network_acls=network_acls)
def replace_network_acl_association(self):
association_id = self.querystring.get('AssociationId')[0]
network_acl_id = self.querystring.get('NetworkAclId')[0]
association_id = self._get_param('AssociationId')
network_acl_id = self._get_param('NetworkAclId')
association = self.ec2_backend.replace_network_acl_association(
association_id,

View File

@ -1,29 +1,25 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import route_table_ids_from_querystring, filters_from_querystring, optional_from_querystring
from moto.ec2.utils import filters_from_querystring
class RouteTables(BaseResponse):
def associate_route_table(self):
route_table_id = self.querystring.get('RouteTableId')[0]
subnet_id = self.querystring.get('SubnetId')[0]
route_table_id = self._get_param('RouteTableId')
subnet_id = self._get_param('SubnetId')
association_id = self.ec2_backend.associate_route_table(
route_table_id, subnet_id)
template = self.response_template(ASSOCIATE_ROUTE_TABLE_RESPONSE)
return template.render(association_id=association_id)
def create_route(self):
route_table_id = self.querystring.get('RouteTableId')[0]
destination_cidr_block = self.querystring.get(
'DestinationCidrBlock')[0]
gateway_id = optional_from_querystring('GatewayId', self.querystring)
instance_id = optional_from_querystring('InstanceId', self.querystring)
interface_id = optional_from_querystring(
'NetworkInterfaceId', self.querystring)
pcx_id = optional_from_querystring(
'VpcPeeringConnectionId', self.querystring)
route_table_id = self._get_param('RouteTableId')
destination_cidr_block = self._get_param('DestinationCidrBlock')
gateway_id = self._get_param('GatewayId')
instance_id = self._get_param('InstanceId')
interface_id = self._get_param('NetworkInterfaceId')
pcx_id = self._get_param('VpcPeeringConnectionId')
self.ec2_backend.create_route(route_table_id, destination_cidr_block,
gateway_id=gateway_id,
@ -35,27 +31,26 @@ class RouteTables(BaseResponse):
return template.render()
def create_route_table(self):
vpc_id = self.querystring.get('VpcId')[0]
vpc_id = self._get_param('VpcId')
route_table = self.ec2_backend.create_route_table(vpc_id)
template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE)
return template.render(route_table=route_table)
def delete_route(self):
route_table_id = self.querystring.get('RouteTableId')[0]
destination_cidr_block = self.querystring.get(
'DestinationCidrBlock')[0]
route_table_id = self._get_param('RouteTableId')
destination_cidr_block = self._get_param('DestinationCidrBlock')
self.ec2_backend.delete_route(route_table_id, destination_cidr_block)
template = self.response_template(DELETE_ROUTE_RESPONSE)
return template.render()
def delete_route_table(self):
route_table_id = self.querystring.get('RouteTableId')[0]
route_table_id = self._get_param('RouteTableId')
self.ec2_backend.delete_route_table(route_table_id)
template = self.response_template(DELETE_ROUTE_TABLE_RESPONSE)
return template.render()
def describe_route_tables(self):
route_table_ids = route_table_ids_from_querystring(self.querystring)
route_table_ids = self._get_multi_param('RouteTableId')
filters = filters_from_querystring(self.querystring)
route_tables = self.ec2_backend.get_all_route_tables(
route_table_ids, filters)
@ -63,22 +58,18 @@ class RouteTables(BaseResponse):
return template.render(route_tables=route_tables)
def disassociate_route_table(self):
association_id = self.querystring.get('AssociationId')[0]
association_id = self._get_param('AssociationId')
self.ec2_backend.disassociate_route_table(association_id)
template = self.response_template(DISASSOCIATE_ROUTE_TABLE_RESPONSE)
return template.render()
def replace_route(self):
route_table_id = self.querystring.get('RouteTableId')[0]
destination_cidr_block = self.querystring.get(
'DestinationCidrBlock')[0]
gateway_id = optional_from_querystring('GatewayId', self.querystring)
instance_id = optional_from_querystring('InstanceId', self.querystring)
interface_id = optional_from_querystring(
'NetworkInterfaceId', self.querystring)
pcx_id = optional_from_querystring(
'VpcPeeringConnectionId', self.querystring)
route_table_id = self._get_param('RouteTableId')
destination_cidr_block = self._get_param('DestinationCidrBlock')
gateway_id = self._get_param('GatewayId')
instance_id = self._get_param('InstanceId')
interface_id = self._get_param('NetworkInterfaceId')
pcx_id = self._get_param('VpcPeeringConnectionId')
self.ec2_backend.replace_route(route_table_id, destination_cidr_block,
gateway_id=gateway_id,
@ -90,8 +81,8 @@ class RouteTables(BaseResponse):
return template.render()
def replace_route_table_association(self):
route_table_id = self.querystring.get('RouteTableId')[0]
association_id = self.querystring.get('AssociationId')[0]
route_table_id = self._get_param('RouteTableId')
association_id = self._get_param('AssociationId')
new_association_id = self.ec2_backend.replace_route_table_association(
association_id, route_table_id)
template = self.response_template(

View File

@ -11,69 +11,66 @@ def try_parse_int(value, default=None):
return default
def process_rules_from_querystring(querystring):
try:
group_name_or_id = querystring.get('GroupName')[0]
except:
group_name_or_id = querystring.get('GroupId')[0]
querytree = {}
for key, value in querystring.items():
key_splitted = key.split('.')
key_splitted = [try_parse_int(e, e) for e in key_splitted]
d = querytree
for subkey in key_splitted[:-1]:
if subkey not in d:
d[subkey] = {}
d = d[subkey]
d[key_splitted[-1]] = value
ip_permissions = querytree.get('IpPermissions') or {}
for ip_permission_idx in sorted(ip_permissions.keys()):
ip_permission = ip_permissions[ip_permission_idx]
ip_protocol = ip_permission.get('IpProtocol', [None])[0]
from_port = ip_permission.get('FromPort', [None])[0]
to_port = ip_permission.get('ToPort', [None])[0]
ip_ranges = []
ip_ranges_tree = ip_permission.get('IpRanges') or {}
for ip_range_idx in sorted(ip_ranges_tree.keys()):
ip_ranges.append(ip_ranges_tree[ip_range_idx]['CidrIp'][0])
source_groups = []
source_group_ids = []
groups_tree = ip_permission.get('Groups') or {}
for group_idx in sorted(groups_tree.keys()):
group_dict = groups_tree[group_idx]
if 'GroupId' in group_dict:
source_group_ids.append(group_dict['GroupId'][0])
elif 'GroupName' in group_dict:
source_groups.append(group_dict['GroupName'][0])
yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges,
source_groups, source_group_ids)
class SecurityGroups(BaseResponse):
def _process_rules_from_querystring(self):
group_name_or_id = (self._get_param('GroupName') or
self._get_param('GroupId'))
querytree = {}
for key, value in self.querystring.items():
key_splitted = key.split('.')
key_splitted = [try_parse_int(e, e) for e in key_splitted]
d = querytree
for subkey in key_splitted[:-1]:
if subkey not in d:
d[subkey] = {}
d = d[subkey]
d[key_splitted[-1]] = value
ip_permissions = querytree.get('IpPermissions') or {}
for ip_permission_idx in sorted(ip_permissions.keys()):
ip_permission = ip_permissions[ip_permission_idx]
ip_protocol = ip_permission.get('IpProtocol', [None])[0]
from_port = ip_permission.get('FromPort', [None])[0]
to_port = ip_permission.get('ToPort', [None])[0]
ip_ranges = []
ip_ranges_tree = ip_permission.get('IpRanges') or {}
for ip_range_idx in sorted(ip_ranges_tree.keys()):
ip_ranges.append(ip_ranges_tree[ip_range_idx]['CidrIp'][0])
source_groups = []
source_group_ids = []
groups_tree = ip_permission.get('Groups') or {}
for group_idx in sorted(groups_tree.keys()):
group_dict = groups_tree[group_idx]
if 'GroupId' in group_dict:
source_group_ids.append(group_dict['GroupId'][0])
elif 'GroupName' in group_dict:
source_groups.append(group_dict['GroupName'][0])
yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges,
source_groups, source_group_ids)
def authorize_security_group_egress(self):
if self.is_not_dryrun('GrantSecurityGroupEgress'):
for args in process_rules_from_querystring(self.querystring):
for args in self._process_rules_from_querystring():
self.ec2_backend.authorize_security_group_egress(*args)
return AUTHORIZE_SECURITY_GROUP_EGRESS_RESPONSE
def authorize_security_group_ingress(self):
if self.is_not_dryrun('GrantSecurityGroupIngress'):
for args in process_rules_from_querystring(self.querystring):
for args in self._process_rules_from_querystring():
self.ec2_backend.authorize_security_group_ingress(*args)
return AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE
def create_security_group(self):
name = self.querystring.get('GroupName')[0]
description = self.querystring.get('GroupDescription', [None])[0]
vpc_id = self.querystring.get("VpcId", [None])[0]
name = self._get_param('GroupName')
description = self._get_param('GroupDescription')
vpc_id = self._get_param('VpcId')
if self.is_not_dryrun('CreateSecurityGroup'):
group = self.ec2_backend.create_security_group(
@ -86,14 +83,14 @@ class SecurityGroups(BaseResponse):
# See
# http://docs.aws.amazon.com/AWSEC2/latest/APIReference/ApiReference-query-DeleteSecurityGroup.html
name = self.querystring.get('GroupName')
sg_id = self.querystring.get('GroupId')
name = self._get_param('GroupName')
sg_id = self._get_param('GroupId')
if self.is_not_dryrun('DeleteSecurityGroup'):
if name:
self.ec2_backend.delete_security_group(name[0])
self.ec2_backend.delete_security_group(name)
elif sg_id:
self.ec2_backend.delete_security_group(group_id=sg_id[0])
self.ec2_backend.delete_security_group(group_id=sg_id)
return DELETE_GROUP_RESPONSE
@ -113,7 +110,7 @@ class SecurityGroups(BaseResponse):
def revoke_security_group_egress(self):
if self.is_not_dryrun('RevokeSecurityGroupEgress'):
for args in process_rules_from_querystring(self.querystring):
for args in self._process_rules_from_querystring():
success = self.ec2_backend.revoke_security_group_egress(*args)
if not success:
return "Could not find a matching egress rule", dict(status=404)
@ -121,7 +118,7 @@ class SecurityGroups(BaseResponse):
def revoke_security_group_ingress(self):
if self.is_not_dryrun('RevokeSecurityGroupIngress'):
for args in process_rules_from_querystring(self.querystring):
for args in self._process_rules_from_querystring():
self.ec2_backend.revoke_security_group_ingress(*args)
return REVOKE_SECURITY_GROUP_INGRESS_REPONSE

View File

@ -7,14 +7,11 @@ from moto.ec2.utils import filters_from_querystring
class Subnets(BaseResponse):
def create_subnet(self):
vpc_id = self.querystring.get('VpcId')[0]
cidr_block = self.querystring.get('CidrBlock')[0]
if 'AvailabilityZone' in self.querystring:
availability_zone = self.querystring['AvailabilityZone'][0]
else:
zone = random.choice(
self.ec2_backend.describe_availability_zones())
availability_zone = zone.name
vpc_id = self._get_param('VpcId')
cidr_block = self._get_param('CidrBlock')
availability_zone = self._get_param(
'AvailabilityZone', if_none=random.choice(
self.ec2_backend.describe_availability_zones()).name)
subnet = self.ec2_backend.create_subnet(
vpc_id,
cidr_block,
@ -24,30 +21,21 @@ class Subnets(BaseResponse):
return template.render(subnet=subnet)
def delete_subnet(self):
subnet_id = self.querystring.get('SubnetId')[0]
subnet_id = self._get_param('SubnetId')
subnet = self.ec2_backend.delete_subnet(subnet_id)
template = self.response_template(DELETE_SUBNET_RESPONSE)
return template.render(subnet=subnet)
def describe_subnets(self):
subnet_ids = self._get_multi_param('SubnetId')
filters = filters_from_querystring(self.querystring)
subnet_ids = []
idx = 1
key = 'SubnetId.{0}'.format(idx)
while key in self.querystring:
v = self.querystring[key]
subnet_ids.append(v[0])
idx += 1
key = 'SubnetId.{0}'.format(idx)
subnets = self.ec2_backend.get_all_subnets(subnet_ids, filters)
template = self.response_template(DESCRIBE_SUBNETS_RESPONSE)
return template.render(subnets=subnets)
def modify_subnet_attribute(self):
subnet_id = self.querystring.get('SubnetId')[0]
map_public_ip = self.querystring.get('MapPublicIpOnLaunch.Value')[0]
subnet_id = self._get_param('SubnetId')
map_public_ip = self._get_param('MapPublicIpOnLaunch.Value')
self.ec2_backend.modify_subnet_attribute(subnet_id, map_public_ip)
return MODIFY_SUBNET_ATTRIBUTE_RESPONSE

View File

@ -2,14 +2,13 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.models import validate_resource_ids
from moto.ec2.utils import sequence_from_querystring, tags_from_query_string, filters_from_querystring
from moto.ec2.utils import tags_from_query_string, filters_from_querystring
class TagResponse(BaseResponse):
def create_tags(self):
resource_ids = sequence_from_querystring(
'ResourceId', self.querystring)
resource_ids = self._get_multi_param('ResourceId')
validate_resource_ids(resource_ids)
self.ec2_backend.do_resources_exist(resource_ids)
tags = tags_from_query_string(self.querystring)
@ -18,8 +17,7 @@ class TagResponse(BaseResponse):
return CREATE_RESPONSE
def delete_tags(self):
resource_ids = sequence_from_querystring(
'ResourceId', self.querystring)
resource_ids = self._get_multi_param('ResourceId')
validate_resource_ids(resource_ids)
tags = tags_from_query_string(self.querystring)
if self.is_not_dryrun('DeleteTags'):

View File

@ -6,8 +6,8 @@ from moto.ec2.utils import filters_from_querystring
class VirtualPrivateGateways(BaseResponse):
def attach_vpn_gateway(self):
vpn_gateway_id = self.querystring.get('VpnGatewayId')[0]
vpc_id = self.querystring.get('VpcId')[0]
vpn_gateway_id = self._get_param('VpnGatewayId')
vpc_id = self._get_param('VpcId')
attachment = self.ec2_backend.attach_vpn_gateway(
vpn_gateway_id,
vpc_id
@ -16,13 +16,13 @@ class VirtualPrivateGateways(BaseResponse):
return template.render(attachment=attachment)
def create_vpn_gateway(self):
type = self.querystring.get('Type', None)[0]
type = self._get_param('Type')
vpn_gateway = self.ec2_backend.create_vpn_gateway(type)
template = self.response_template(CREATE_VPN_GATEWAY_RESPONSE)
return template.render(vpn_gateway=vpn_gateway)
def delete_vpn_gateway(self):
vpn_gateway_id = self.querystring.get('VpnGatewayId')[0]
vpn_gateway_id = self._get_param('VpnGatewayId')
vpn_gateway = self.ec2_backend.delete_vpn_gateway(vpn_gateway_id)
template = self.response_template(DELETE_VPN_GATEWAY_RESPONSE)
return template.render(vpn_gateway=vpn_gateway)
@ -34,8 +34,8 @@ class VirtualPrivateGateways(BaseResponse):
return template.render(vpn_gateways=vpn_gateways)
def detach_vpn_gateway(self):
vpn_gateway_id = self.querystring.get('VpnGatewayId')[0]
vpc_id = self.querystring.get('VpcId')[0]
vpn_gateway_id = self._get_param('VpnGatewayId')
vpc_id = self._get_param('VpcId')
attachment = self.ec2_backend.detach_vpn_gateway(
vpn_gateway_id,
vpc_id

View File

@ -5,16 +5,15 @@ from moto.core.responses import BaseResponse
class VPCPeeringConnections(BaseResponse):
def create_vpc_peering_connection(self):
vpc = self.ec2_backend.get_vpc(self.querystring.get('VpcId')[0])
peer_vpc = self.ec2_backend.get_vpc(
self.querystring.get('PeerVpcId')[0])
vpc = self.ec2_backend.get_vpc(self._get_param('VpcId'))
peer_vpc = self.ec2_backend.get_vpc(self._get_param('PeerVpcId'))
vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc)
template = self.response_template(
CREATE_VPC_PEERING_CONNECTION_RESPONSE)
return template.render(vpc_pcx=vpc_pcx)
def delete_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0]
vpc_pcx_id = self._get_param('VpcPeeringConnectionId')
vpc_pcx = self.ec2_backend.delete_vpc_peering_connection(vpc_pcx_id)
template = self.response_template(
DELETE_VPC_PEERING_CONNECTION_RESPONSE)
@ -27,14 +26,14 @@ class VPCPeeringConnections(BaseResponse):
return template.render(vpc_pcxs=vpc_pcxs)
def accept_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0]
vpc_pcx_id = self._get_param('VpcPeeringConnectionId')
vpc_pcx = self.ec2_backend.accept_vpc_peering_connection(vpc_pcx_id)
template = self.response_template(
ACCEPT_VPC_PEERING_CONNECTION_RESPONSE)
return template.render(vpc_pcx=vpc_pcx)
def reject_vpc_peering_connection(self):
vpc_pcx_id = self.querystring.get('VpcPeeringConnectionId')[0]
vpc_pcx_id = self._get_param('VpcPeeringConnectionId')
self.ec2_backend.reject_vpc_peering_connection(vpc_pcx_id)
template = self.response_template(
REJECT_VPC_PEERING_CONNECTION_RESPONSE)

View File

@ -1,42 +1,41 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring, vpc_ids_from_querystring
from moto.ec2.utils import filters_from_querystring
class VPCs(BaseResponse):
def create_vpc(self):
cidr_block = self.querystring.get('CidrBlock')[0]
instance_tenancy = self.querystring.get(
'InstanceTenancy', ['default'])[0]
cidr_block = self._get_param('CidrBlock')
instance_tenancy = self._get_param('InstanceTenancy', if_none='default')
vpc = self.ec2_backend.create_vpc(cidr_block, instance_tenancy)
template = self.response_template(CREATE_VPC_RESPONSE)
return template.render(vpc=vpc)
def delete_vpc(self):
vpc_id = self.querystring.get('VpcId')[0]
vpc_id = self._get_param('VpcId')
vpc = self.ec2_backend.delete_vpc(vpc_id)
template = self.response_template(DELETE_VPC_RESPONSE)
return template.render(vpc=vpc)
def describe_vpcs(self):
vpc_ids = vpc_ids_from_querystring(self.querystring)
vpc_ids = self._get_multi_param('VpcId')
filters = filters_from_querystring(self.querystring)
vpcs = self.ec2_backend.get_all_vpcs(vpc_ids=vpc_ids, filters=filters)
template = self.response_template(DESCRIBE_VPCS_RESPONSE)
return template.render(vpcs=vpcs)
def describe_vpc_attribute(self):
vpc_id = self.querystring.get('VpcId')[0]
attribute = self.querystring.get('Attribute')[0]
vpc_id = self._get_param('VpcId')
attribute = self._get_param('Attribute')
attr_name = camelcase_to_underscores(attribute)
value = self.ec2_backend.describe_vpc_attribute(vpc_id, attr_name)
template = self.response_template(DESCRIBE_VPC_ATTRIBUTE_RESPONSE)
return template.render(vpc_id=vpc_id, attribute=attribute, value=value)
def modify_vpc_attribute(self):
vpc_id = self.querystring.get('VpcId')[0]
vpc_id = self._get_param('VpcId')
for attribute in ('EnableDnsSupport', 'EnableDnsHostnames'):
if self.querystring.get('%s.Value' % attribute):

View File

@ -1,30 +1,29 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring, sequence_from_querystring
from moto.ec2.utils import filters_from_querystring
class VPNConnections(BaseResponse):
def create_vpn_connection(self):
type = self.querystring.get("Type", [None])[0]
cgw_id = self.querystring.get("CustomerGatewayId", [None])[0]
vgw_id = self.querystring.get("VPNGatewayId", [None])[0]
static_routes = self.querystring.get("StaticRoutesOnly", [None])[0]
type = self._get_param('Type')
cgw_id = self._get_param('CustomerGatewayId')
vgw_id = self._get_param('VPNGatewayId')
static_routes = self._get_param('StaticRoutesOnly')
vpn_connection = self.ec2_backend.create_vpn_connection(
type, cgw_id, vgw_id, static_routes_only=static_routes)
template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection)
def delete_vpn_connection(self):
vpn_connection_id = self.querystring.get('VpnConnectionId')[0]
vpn_connection_id = self._get_param('VpnConnectionId')
vpn_connection = self.ec2_backend.delete_vpn_connection(
vpn_connection_id)
template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection)
def describe_vpn_connections(self):
vpn_connection_ids = sequence_from_querystring(
'VpnConnectionId', self.querystring)
vpn_connection_ids = self._get_multi_param('VpnConnectionId')
filters = filters_from_querystring(self.querystring)
vpn_connections = self.ec2_backend.get_all_vpn_connections(
vpn_connection_ids=vpn_connection_ids, filters=filters)

View File

@ -174,62 +174,6 @@ def split_route_id(route_id):
return values[0], values[1]
def instance_ids_from_querystring(querystring_dict):
instance_ids = []
for key, value in querystring_dict.items():
if 'InstanceId' in key:
instance_ids.append(value[0])
return instance_ids
def image_ids_from_querystring(querystring_dict):
image_ids = []
for key, value in querystring_dict.items():
if 'ImageId' in key:
image_ids.append(value[0])
return image_ids
def executable_users_from_querystring(querystring_dict):
user_ids = []
for key, value in querystring_dict.items():
if 'ExecutableBy' in key:
user_ids.append(value[0])
return user_ids
def route_table_ids_from_querystring(querystring_dict):
route_table_ids = []
for key, value in querystring_dict.items():
if 'RouteTableId' in key:
route_table_ids.append(value[0])
return route_table_ids
def network_acl_ids_from_querystring(querystring_dict):
network_acl_ids = []
for key, value in querystring_dict.items():
if 'NetworkAclId' in key:
network_acl_ids.append(value[0])
return network_acl_ids
def vpc_ids_from_querystring(querystring_dict):
vpc_ids = []
for key, value in querystring_dict.items():
if 'VpcId' in key:
vpc_ids.append(value[0])
return vpc_ids
def sequence_from_querystring(parameter, querystring_dict):
parameter_values = []
for key, value in querystring_dict.items():
if parameter in key:
parameter_values.append(value[0])
return parameter_values
def tags_from_query_string(querystring_dict):
prefix = 'Tag'
suffix = 'Key'
@ -286,11 +230,6 @@ def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration'
return response_values
def optional_from_querystring(parameter, querystring):
parameter_array = querystring.get(parameter)
return parameter_array[0] if parameter_array else None
def filters_from_querystring(querystring_dict):
response_values = {}
for key, value in querystring_dict.items():
@ -319,14 +258,6 @@ def dict_from_querystring(parameter, querystring_dict):
return use_dict
def keypair_names_from_querystring(querystring_dict):
keypair_names = []
for key, value in querystring_dict.items():
if 'KeyName' in key:
keypair_names.append(value[0])
return keypair_names
def get_object_value(obj, attr):
keys = attr.split('.')
val = obj
@ -335,6 +266,11 @@ def get_object_value(obj, attr):
val = getattr(val, key)
elif isinstance(val, dict):
val = val[key]
elif isinstance(val, list):
for item in val:
item_val = get_object_value(item, key)
if item_val:
return item_val
else:
return None
return val
@ -385,14 +321,17 @@ filter_dict_attribute_mapping = {
'state-reason-code': '_state_reason.code',
'source-dest-check': 'source_dest_check',
'vpc-id': 'vpc_id',
'group-id': 'security_groups',
'instance.group-id': 'security_groups',
'group-id': 'security_groups.id',
'instance.group-id': 'security_groups.id',
'instance.group-name': 'security_groups.name',
'instance-type': 'instance_type',
'private-ip-address': 'private_ip',
'ip-address': 'public_ip',
'availability-zone': 'placement',
'architecture': 'architecture',
'image-id': 'image_id'
'image-id': 'image_id',
'network-interface.private-dns-name': 'private_dns',
'private-dns-name': 'private_dns'
}

22
moto/ecr/exceptions.py Normal file
View File

@ -0,0 +1,22 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
class RepositoryNotFoundException(RESTError):
code = 400
def __init__(self, repository_name, registry_id):
super(RepositoryNotFoundException, self).__init__(
error_type="RepositoryNotFoundException",
message="The repository with name '{0}' does not exist in the registry "
"with id '{1}'".format(repository_name, registry_id))
class ImageNotFoundException(RESTError):
code = 400
def __init__(self, image_id, repository_name, registry_id):
super(ImageNotFoundException, self).__init__(
error_type="ImageNotFoundException",
message="The image with imageId {0} does not exist within the repository with name '{1}' "
"in the registry with id '{2}'".format(image_id, repository_name, registry_id))

View File

@ -7,6 +7,11 @@ from moto.ec2 import ec2_backends
from copy import copy
import hashlib
from moto.ecr.exceptions import ImageNotFoundException, RepositoryNotFoundException
DEFAULT_REGISTRY_ID = '012345678910'
class BaseObject(BaseModel):
@ -35,14 +40,13 @@ class BaseObject(BaseModel):
class Repository(BaseObject):
def __init__(self, repository_name):
self.arn = 'arn:aws:ecr:us-east-1:012345678910:repository/{0}'.format(
repository_name)
self.registry_id = DEFAULT_REGISTRY_ID
self.arn = 'arn:aws:ecr:us-east-1:{0}:repository/{1}'.format(
self.registry_id, repository_name)
self.name = repository_name
# self.created = datetime.utcnow()
self.uri = '012345678910.dkr.ecr.us-east-1.amazonaws.com/{0}'.format(
repository_name
)
self.registry_id = '012345678910'
self.uri = '{0}.dkr.ecr.us-east-1.amazonaws.com/{1}'.format(
self.registry_id, repository_name)
self.images = []
@property
@ -93,7 +97,7 @@ class Repository(BaseObject):
class Image(BaseObject):
def __init__(self, tag, manifest, repository, registry_id="012345678910"):
def __init__(self, tag, manifest, repository, registry_id=DEFAULT_REGISTRY_ID):
self.image_tag = tag
self.image_manifest = manifest
self.image_size_in_bytes = 50 * 1024 * 1024
@ -151,6 +155,11 @@ class ECRBackend(BaseBackend):
"""
maxResults and nextToken not implemented
"""
if repository_names:
for repository_name in repository_names:
if repository_name not in self.repositories:
raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID)
repositories = []
for repository in self.repositories.values():
# If a registry_id was supplied, ensure this repository matches
@ -170,11 +179,11 @@ class ECRBackend(BaseBackend):
self.repositories[repository_name] = repository
return repository
def delete_repository(self, respository_name, registry_id=None):
if respository_name in self.repositories:
return self.repositories.pop(respository_name)
def delete_repository(self, repository_name, registry_id=None):
if repository_name in self.repositories:
return self.repositories.pop(repository_name)
else:
raise Exception("{0} is not a repository".format(respository_name))
raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID)
def list_images(self, repository_name, registry_id=None):
"""
@ -198,17 +207,27 @@ class ECRBackend(BaseBackend):
if repository_name in self.repositories:
repository = self.repositories[repository_name]
else:
raise Exception("{0} is not a repository".format(repository_name))
raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID)
if image_ids:
response = set()
for image_id in image_ids:
if 'imageDigest' in image_id:
desired_digest = image_id['imageDigest']
response.update([i for i in repository.images if i.get_image_digest() == desired_digest])
if 'imageTag' in image_id:
desired_tag = image_id['imageTag']
response.update([i for i in repository.images if i.image_tag == desired_tag])
found = False
for image in repository.images:
if (('imageDigest' in image_id and image.get_image_digest() == image_id['imageDigest']) or
('imageTag' in image_id and image.image_tag == image_id['imageTag'])):
found = True
response.add(image)
if not found:
image_id_representation = "{imageDigest:'%s', imageTag:'%s'}" % (
image_id.get('imageDigest', 'null'),
image_id.get('imageTag', 'null'),
)
raise ImageNotFoundException(
image_id=image_id_representation,
repository_name=repository_name,
registry_id=registry_id or DEFAULT_REGISTRY_ID)
else:
response = []
for image in repository.images:

View File

@ -45,7 +45,8 @@ class ECRResponse(BaseResponse):
def delete_repository(self):
repository_str = self._get_param('repositoryName')
repository = self.ecr_backend.delete_repository(repository_str)
registry_id = self._get_param('registryId')
repository = self.ecr_backend.delete_repository(repository_str, registry_id)
return json.dumps({
'repository': repository.response_object
})

View File

@ -114,7 +114,7 @@ class TaskDefinition(BaseObject):
family = properties.get(
'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6)))
container_definitions = properties['ContainerDefinitions']
volumes = properties['Volumes']
volumes = properties.get('Volumes')
ecs_backend = ecs_backends[region_name]
return ecs_backend.register_task_definition(
@ -127,7 +127,7 @@ class TaskDefinition(BaseObject):
family = properties.get(
'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6)))
container_definitions = properties['ContainerDefinitions']
volumes = properties['Volumes']
volumes = properties.get('Volumes')
if (original_resource.family != family or
original_resource.container_definitions != container_definitions or
original_resource.volumes != volumes):

View File

@ -18,8 +18,8 @@ class EC2ContainerServiceResponse(BaseResponse):
except ValueError:
return {}
def _get_param(self, param):
return self.request_params.get(param, None)
def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none)
def create_cluster(self):
cluster_name = self._get_param('clusterName')

View File

@ -42,7 +42,7 @@ class SubnetNotFoundError(ELBClientError):
class TargetGroupNotFoundError(ELBClientError):
def __init__(self):
super(TooManyTagsError, self).__init__(
super(TargetGroupNotFoundError, self).__init__(
"TargetGroupNotFound",
"The specified target group does not exist.")
@ -101,3 +101,85 @@ class EmptyListenersError(ELBClientError):
super(EmptyListenersError, self).__init__(
"ValidationError",
"Listeners cannot be empty")
class PriorityInUseError(ELBClientError):
def __init__(self):
super(PriorityInUseError, self).__init__(
"PriorityInUse",
"The specified priority is in use.")
class InvalidConditionFieldError(ELBClientError):
def __init__(self, invalid_name):
super(InvalidConditionFieldError, self).__init__(
"ValidationError",
"Condition field '%s' must be one of '[path-pattern, host-header]" % (invalid_name))
class InvalidConditionValueError(ELBClientError):
def __init__(self, msg):
super(InvalidConditionValueError, self).__init__(
"ValidationError", msg)
class InvalidActionTypeError(ELBClientError):
def __init__(self, invalid_name, index):
super(InvalidActionTypeError, self).__init__(
"ValidationError",
"1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward]" % (invalid_name, index)
)
class ActionTargetGroupNotFoundError(ELBClientError):
def __init__(self, arn):
super(ActionTargetGroupNotFoundError, self).__init__(
"TargetGroupNotFound",
"Target group '%s' not found" % arn
)
class InvalidDescribeRulesRequest(ELBClientError):
def __init__(self, msg):
super(InvalidDescribeRulesRequest, self).__init__(
"ValidationError", msg
)
class RuleNotFoundError(ELBClientError):
def __init__(self):
super(RuleNotFoundError, self).__init__(
"RuleNotFound",
"The specified rule does not exist.")
class DuplicatePriorityError(ELBClientError):
def __init__(self, invalid_value):
super(DuplicatePriorityError, self).__init__(
"ValidationError",
"Priority '%s' was provided multiple times" % invalid_value)
class InvalidTargetGroupNameError(ELBClientError):
def __init__(self, msg):
super(InvalidTargetGroupNameError, self).__init__(
"ValidationError", msg
)
class InvalidModifyRuleArgumentsError(ELBClientError):
def __init__(self):
super(InvalidModifyRuleArgumentsError, self).__init__(
"ValidationError",
"Either conditions or actions must be specified"
)

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
import datetime
import re
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.ec2.models import ec2_backends
@ -14,6 +15,16 @@ from .exceptions import (
SubnetNotFoundError,
TargetGroupNotFoundError,
TooManyTagsError,
PriorityInUseError,
InvalidConditionFieldError,
InvalidConditionValueError,
InvalidActionTypeError,
ActionTargetGroupNotFoundError,
InvalidDescribeRulesRequest,
RuleNotFoundError,
DuplicatePriorityError,
InvalidTargetGroupNameError,
InvalidModifyRuleArgumentsError
)
@ -54,6 +65,7 @@ class FakeTargetGroup(BaseModel):
self.healthy_threshold_count = healthy_threshold_count
self.unhealthy_threshold_count = unhealthy_threshold_count
self.load_balancer_arns = []
self.tags = {}
self.attributes = {
'deregistration_delay.timeout_seconds': 300,
@ -71,10 +83,15 @@ class FakeTargetGroup(BaseModel):
def deregister(self, targets):
for target in targets:
t = self.targets.pop(target['id'])
t = self.targets.pop(target['id'], None)
if not t:
raise InvalidTargetError()
def add_tag(self, key, value):
if len(self.tags) >= 10 and key not in self.tags:
raise TooManyTagsError()
self.tags[key] = value
def health_for(self, target):
t = self.targets.get(target['id'])
if t is None:
@ -92,6 +109,36 @@ class FakeListener(BaseModel):
self.ssl_policy = ssl_policy
self.certificate = certificate
self.default_actions = default_actions
self._non_default_rules = []
self._default_rule = FakeRule(
listener_arn=self.arn,
conditions=[],
priority='default',
actions=default_actions,
is_default=True
)
@property
def rules(self):
return self._non_default_rules + [self._default_rule]
def remove_rule(self, rule):
self._non_default_rules.remove(rule)
def register(self, rule):
self._non_default_rules.append(rule)
self._non_default_rules = sorted(self._non_default_rules, key=lambda x: x.priority)
class FakeRule(BaseModel):
def __init__(self, listener_arn, conditions, priority, actions, is_default):
self.listener_arn = listener_arn
self.arn = listener_arn.replace(':listener/', ':listener-rule/') + "/%s" % (id(self))
self.conditions = conditions
self.priority = priority # int or 'default'
self.actions = actions
self.is_default = is_default
class FakeBackend(BaseModel):
@ -181,7 +228,73 @@ class ELBv2Backend(BaseBackend):
self.load_balancers[arn] = new_load_balancer
return new_load_balancer
def create_rule(self, listener_arn, conditions, priority, actions):
listeners = self.describe_listeners(None, [listener_arn])
if not listeners:
raise ListenerNotFoundError()
listener = listeners[0]
# validate conditions
for condition in conditions:
field = condition['field']
if field not in ['path-pattern', 'host-header']:
raise InvalidConditionFieldError(field)
values = condition['values']
if len(values) == 0:
raise InvalidConditionValueError('A condition value must be specified')
if len(values) > 1:
raise InvalidConditionValueError(
"The '%s' field contains too many values; the limit is '1'" % field
)
# TODO: check pattern of value for 'host-header'
# TODO: check pattern of value for 'path-pattern'
# validate Priority
for rule in listener.rules:
if rule.priority == priority:
raise PriorityInUseError()
# validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
for i, action in enumerate(actions):
index = i + 1
action_type = action['type']
if action_type not in ['forward']:
raise InvalidActionTypeError(action_type, index)
action_target_group_arn = action['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
# TODO: check for error 'TooManyRegistrationsForTargetId'
# TODO: check for error 'TooManyRules'
# create rule
rule = FakeRule(listener.arn, conditions, priority, actions, is_default=False)
listener.register(rule)
return [rule]
def create_target_group(self, name, **kwargs):
if len(name) > 32:
raise InvalidTargetGroupNameError(
"Target group name '%s' cannot be longer than '32' characters" % name
)
if not re.match('^[a-zA-Z0-9\-]+$', name):
raise InvalidTargetGroupNameError(
"Target group name '%s' can only contain characters that are alphanumeric characters or hyphens(-)" % name
)
# undocumented validation
if not re.match('(?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$', name):
raise InvalidTargetGroupNameError(
"1 validation error detected: Value '%s' at 'targetGroup.targetGroupArn.targetGroupName' failed to satisfy constraint: Member must satisfy regular expression pattern: (?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$" % name
)
if name.startswith('-') or name.endswith('-'):
raise InvalidTargetGroupNameError(
"Target group name '%s' cannot begin or end with '-'" % name
)
for target_group in self.target_groups.values():
if target_group.name == name:
raise DuplicateTargetGroupName()
@ -233,6 +346,29 @@ class ELBv2Backend(BaseBackend):
return matched_balancers
def describe_rules(self, listener_arn, rule_arns):
if listener_arn is None and not rule_arns:
raise InvalidDescribeRulesRequest(
"You must specify either listener rule ARNs or a listener ARN"
)
if listener_arn is not None and rule_arns is not None:
raise InvalidDescribeRulesRequest(
'Listener rule ARNs and a listener ARN cannot be specified at the same time'
)
if listener_arn:
listener = self.describe_listeners(None, [listener_arn])[0]
return listener.rules
# search for rule arns
matched_rules = []
for load_balancer_arn in self.load_balancers:
listeners = self.load_balancers.get(load_balancer_arn).listeners.values()
for listener in listeners:
for rule in listener.rules:
if rule.arn in rule_arns:
matched_rules.append(rule)
return matched_rules
def describe_target_groups(self, load_balancer_arn, target_group_arns, names):
if load_balancer_arn:
if load_balancer_arn not in self.load_balancers:
@ -249,7 +385,7 @@ class ELBv2Backend(BaseBackend):
matched = []
for name in names:
found = None
for target_group in self.target_groups:
for target_group in self.target_groups.values():
if target_group.name == name:
found = target_group
if not found:
@ -277,19 +413,78 @@ class ELBv2Backend(BaseBackend):
def delete_load_balancer(self, arn):
self.load_balancers.pop(arn, None)
def delete_rule(self, arn):
for load_balancer_arn in self.load_balancers:
listeners = self.load_balancers.get(load_balancer_arn).listeners.values()
for listener in listeners:
for rule in listener.rules:
if rule.arn == arn:
listener.remove_rule(rule)
return
# should raise RuleNotFound Error according to the AWS API doc
# however, boto3 does't raise error even if rule is not found
def delete_target_group(self, target_group_arn):
target_group = self.target_groups.pop(target_group_arn)
target_group = self.target_groups.pop(target_group_arn, None)
if target_group:
return target_group
raise TargetGroupNotFoundError()
def delete_listener(self, listener_arn):
for load_balancer in self.load_balancers.values():
listener = load_balancer.listeners.pop(listener_arn)
listener = load_balancer.listeners.pop(listener_arn, None)
if listener:
return listener
raise ListenerNotFoundError()
def modify_rule(self, rule_arn, conditions, actions):
# if conditions or actions is empty list, do not update the attributes
if not conditions and not actions:
raise InvalidModifyRuleArgumentsError()
rules = self.describe_rules(listener_arn=None, rule_arns=[rule_arn])
if not rules:
raise RuleNotFoundError()
rule = rules[0]
if conditions:
for condition in conditions:
field = condition['field']
if field not in ['path-pattern', 'host-header']:
raise InvalidConditionFieldError(field)
values = condition['values']
if len(values) == 0:
raise InvalidConditionValueError('A condition value must be specified')
if len(values) > 1:
raise InvalidConditionValueError(
"The '%s' field contains too many values; the limit is '1'" % field
)
# TODO: check pattern of value for 'host-header'
# TODO: check pattern of value for 'path-pattern'
# validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
if actions:
for i, action in enumerate(actions):
index = i + 1
action_type = action['type']
if action_type not in ['forward']:
raise InvalidActionTypeError(action_type, index)
action_target_group_arn = action['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
# TODO: check for error 'TooManyRegistrationsForTargetId'
# TODO: check for error 'TooManyRules'
# modify rule
if conditions:
rule.conditions = conditions
if actions:
rule.actions = actions
return [rule]
def register_targets(self, target_group_arn, instances):
target_group = self.target_groups.get(target_group_arn)
if target_group is None:
@ -311,6 +506,39 @@ class ELBv2Backend(BaseBackend):
targets = target_group.targets.values()
return [target_group.health_for(target) for target in targets]
def set_rule_priorities(self, rule_priorities):
# validate
priorities = [rule_priority['priority'] for rule_priority in rule_priorities]
for priority in set(priorities):
if priorities.count(priority) > 1:
raise DuplicatePriorityError(priority)
# validate
for rule_priority in rule_priorities:
given_rule_arn = rule_priority['rule_arn']
priority = rule_priority['priority']
_given_rules = self.describe_rules(listener_arn=None, rule_arns=[given_rule_arn])
if not _given_rules:
raise RuleNotFoundError()
given_rule = _given_rules[0]
listeners = self.describe_listeners(None, [given_rule.listener_arn])
listener = listeners[0]
for rule_in_listener in listener.rules:
if rule_in_listener.priority == priority:
raise PriorityInUseError()
# modify
modified_rules = []
for rule_priority in rule_priorities:
given_rule_arn = rule_priority['rule_arn']
priority = rule_priority['priority']
_given_rules = self.describe_rules(listener_arn=None, rule_arns=[given_rule_arn])
if not _given_rules:
raise RuleNotFoundError()
given_rule = _given_rules[0]
given_rule.priority = priority
modified_rules.append(given_rule)
return modified_rules
elbv2_backends = {}
for region in ec2_backends.keys():

View File

@ -28,6 +28,30 @@ class ELBV2Response(BaseResponse):
template = self.response_template(CREATE_LOAD_BALANCER_TEMPLATE)
return template.render(load_balancer=load_balancer)
def create_rule(self):
lister_arn = self._get_param('ListenerArn')
_conditions = self._get_list_prefix('Conditions.member')
conditions = []
for _condition in _conditions:
condition = {}
condition['field'] = _condition['field']
values = sorted(
[e for e in _condition.items() if e[0].startswith('values.member')],
key=lambda x: x[0]
)
condition['values'] = [e[1] for e in values]
conditions.append(condition)
priority = self._get_int_param('Priority')
actions = self._get_list_prefix('Actions.member')
rules = self.elbv2_backend.create_rule(
listener_arn=lister_arn,
conditions=conditions,
priority=priority,
actions=actions
)
template = self.response_template(CREATE_RULE_TEMPLATE)
return template.render(rules=rules)
def create_target_group(self):
name = self._get_param('Name')
vpc_id = self._get_param('VpcId')
@ -100,6 +124,26 @@ class ELBV2Response(BaseResponse):
template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE)
return template.render(load_balancers=load_balancers_resp, marker=next_marker)
def describe_rules(self):
listener_arn = self._get_param('ListenerArn')
rule_arns = self._get_multi_param('RuleArns.member') if any(k for k in list(self.querystring.keys()) if k.startswith('RuleArns.member')) else None
all_rules = self.elbv2_backend.describe_rules(listener_arn, rule_arns)
all_arns = [rule.arn for rule in all_rules]
page_size = self._get_int_param('PageSize', 50) # set 50 for temporary
marker = self._get_param('Marker')
if marker:
start = all_arns.index(marker) + 1
else:
start = 0
rules_resp = all_rules[start:start + page_size]
next_marker = None
if len(all_rules) > start + page_size:
next_marker = rules_resp[-1].arn
template = self.response_template(DESCRIBE_RULES_TEMPLATE)
return template.render(rules=rules_resp, marker=next_marker)
def describe_target_groups(self):
load_balancer_arn = self._get_param('LoadBalancerArn')
target_group_arns = self._get_multi_param('TargetGroupArns.member')
@ -133,6 +177,12 @@ class ELBV2Response(BaseResponse):
template = self.response_template(DELETE_LOAD_BALANCER_TEMPLATE)
return template.render()
def delete_rule(self):
arn = self._get_param('RuleArn')
self.elbv2_backend.delete_rule(arn)
template = self.response_template(DELETE_RULE_TEMPLATE)
return template.render()
def delete_target_group(self):
arn = self._get_param('TargetGroupArn')
self.elbv2_backend.delete_target_group(arn)
@ -145,6 +195,28 @@ class ELBV2Response(BaseResponse):
template = self.response_template(DELETE_LISTENER_TEMPLATE)
return template.render()
def modify_rule(self):
rule_arn = self._get_param('RuleArn')
_conditions = self._get_list_prefix('Conditions.member')
conditions = []
for _condition in _conditions:
condition = {}
condition['field'] = _condition['field']
values = sorted(
[e for e in _condition.items() if e[0].startswith('values.member')],
key=lambda x: x[0]
)
condition['values'] = [e[1] for e in values]
conditions.append(condition)
actions = self._get_list_prefix('Actions.member')
rules = self.elbv2_backend.modify_rule(
rule_arn=rule_arn,
conditions=conditions,
actions=actions
)
template = self.response_template(MODIFY_RULE_TEMPLATE)
return template.render(rules=rules)
def modify_target_group_attributes(self):
target_group_arn = self._get_param('TargetGroupArn')
target_group = self.elbv2_backend.target_groups.get(target_group_arn)
@ -182,14 +254,29 @@ class ELBV2Response(BaseResponse):
template = self.response_template(DESCRIBE_TARGET_HEALTH_TEMPLATE)
return template.render(target_health_descriptions=target_health_descriptions)
def set_rule_priorities(self):
rule_priorities = self._get_list_prefix('RulePriorities.member')
for rule_priority in rule_priorities:
rule_priority['priority'] = int(rule_priority['priority'])
rules = self.elbv2_backend.set_rule_priorities(rule_priorities)
template = self.response_template(SET_RULE_PRIORITIES_TEMPLATE)
return template.render(rules=rules)
def add_tags(self):
resource_arns = self._get_multi_param('ResourceArns.member')
for arn in resource_arns:
load_balancer = self.elbv2_backend.load_balancers.get(arn)
if not load_balancer:
if ':targetgroup' in arn:
resource = self.elbv2_backend.target_groups.get(arn)
if not resource:
raise TargetGroupNotFoundError()
elif ':loadbalancer' in arn:
resource = self.elbv2_backend.load_balancers.get(arn)
if not resource:
raise LoadBalancerNotFoundError()
else:
raise LoadBalancerNotFoundError()
self._add_tags(load_balancer)
self._add_tags(resource)
template = self.response_template(ADD_TAGS_TEMPLATE)
return template.render()
@ -199,30 +286,41 @@ class ELBV2Response(BaseResponse):
tag_keys = self._get_multi_param('TagKeys.member')
for arn in resource_arns:
load_balancer = self.elbv2_backend.load_balancers.get(arn)
if not load_balancer:
if ':targetgroup' in arn:
resource = self.elbv2_backend.target_groups.get(arn)
if not resource:
raise TargetGroupNotFoundError()
elif ':loadbalancer' in arn:
resource = self.elbv2_backend.load_balancers.get(arn)
if not resource:
raise LoadBalancerNotFoundError()
else:
raise LoadBalancerNotFoundError()
[load_balancer.remove_tag(key) for key in tag_keys]
[resource.remove_tag(key) for key in tag_keys]
template = self.response_template(REMOVE_TAGS_TEMPLATE)
return template.render()
def describe_tags(self):
elbs = []
for key, value in self.querystring.items():
if "ResourceArns.member" in key:
number = key.split('.')[2]
load_balancer_arn = self._get_param(
'ResourceArns.member.{0}'.format(number))
elb = self.elbv2_backend.load_balancers.get(load_balancer_arn)
if not elb:
resource_arns = self._get_multi_param('ResourceArns.member')
resources = []
for arn in resource_arns:
if ':targetgroup' in arn:
resource = self.elbv2_backend.target_groups.get(arn)
if not resource:
raise TargetGroupNotFoundError()
elif ':loadbalancer' in arn:
resource = self.elbv2_backend.load_balancers.get(arn)
if not resource:
raise LoadBalancerNotFoundError()
elbs.append(elb)
else:
raise LoadBalancerNotFoundError()
resources.append(resource)
template = self.response_template(DESCRIBE_TAGS_TEMPLATE)
return template.render(load_balancers=elbs)
return template.render(resources=resources)
def _add_tags(self, elb):
def _add_tags(self, resource):
tag_values = []
tag_keys = []
@ -244,7 +342,7 @@ class ELBV2Response(BaseResponse):
raise DuplicateTagKeysError(counts[0])
for tag_key, tag_value in zip(tag_keys, tag_values):
elb.add_tag(tag_key, tag_value)
resource.add_tag(tag_key, tag_value)
ADD_TAGS_TEMPLATE = """<AddTagsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
@ -264,11 +362,11 @@ REMOVE_TAGS_TEMPLATE = """<RemoveTagsResponse xmlns="http://elasticloadbalancing
DESCRIBE_TAGS_TEMPLATE = """<DescribeTagsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeTagsResult>
<TagDescriptions>
{% for load_balancer in load_balancers %}
{% for resource in resources %}
<member>
<ResourceArn>{{ load_balancer.arn }}</ResourceArn>
<ResourceArn>{{ resource.arn }}</ResourceArn>
<Tags>
{% for key, value in load_balancer.tags.items() %}
{% for key, value in resource.tags.items() %}
<member>
<Value>{{ value }}</Value>
<Key>{{ key }}</Key>
@ -321,6 +419,43 @@ CREATE_LOAD_BALANCER_TEMPLATE = """<CreateLoadBalancerResponse xmlns="http://ela
</ResponseMetadata>
</CreateLoadBalancerResponse>"""
CREATE_RULE_TEMPLATE = """<CreateRuleResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<CreateRuleResult>
<Rules>
{% for rule in rules %}
<member>
<IsDefault>{{ "true" if rule.is_default else "false" }}</IsDefault>
<Conditions>
{% for condition in rule.conditions %}
<member>
<Field>{{ condition["field"] }}</Field>
<Values>
{% for value in condition["values"] %}
<member>{{ value }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</Conditions>
<Priority>{{ rule.priority }}</Priority>
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member>
{% endfor %}
</Actions>
<RuleArn>{{ rule.arn }}</RuleArn>
</member>
{% endfor %}
</Rules>
</CreateRuleResult>
<ResponseMetadata>
<RequestId>c5478c83-f397-11e5-bb98-57195a6eb84a</RequestId>
</ResponseMetadata>
</CreateRuleResponse>"""
CREATE_TARGET_GROUP_TEMPLATE = """<CreateTargetGroupResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<CreateTargetGroupResult>
<TargetGroups>
@ -387,6 +522,13 @@ DELETE_LOAD_BALANCER_TEMPLATE = """<DeleteLoadBalancerResponse xmlns="http://ela
</ResponseMetadata>
</DeleteLoadBalancerResponse>"""
DELETE_RULE_TEMPLATE = """<DeleteRuleResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DeleteRuleResult/>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
</ResponseMetadata>
</DeleteRuleResponse>"""
DELETE_TARGET_GROUP_TEMPLATE = """<DeleteTargetGroupResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DeleteTargetGroupResult/>
<ResponseMetadata>
@ -442,6 +584,45 @@ DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http
</ResponseMetadata>
</DescribeLoadBalancersResponse>"""
DESCRIBE_RULES_TEMPLATE = """<DescribeRulesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeRulesResult>
<Rules>
{% for rule in rules %}
<member>
<IsDefault>{{ "true" if rule.is_default else "false" }}</IsDefault>
<Conditions>
{% for condition in rule.conditions %}
<member>
<Field>{{ condition["field"] }}</Field>
<Values>
{% for value in condition["values"] %}
<member>{{ value }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</Conditions>
<Priority>{{ rule.priority }}</Priority>
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member>
{% endfor %}
</Actions>
<RuleArn>{{ rule.arn }}</RuleArn>
</member>
{% endfor %}
</Rules>
{% if marker %}
<NextMarker>{{ marker }}</NextMarker>
{% endif %}
</DescribeRulesResult>
<ResponseMetadata>
<RequestId>74926cf3-f3a3-11e5-b543-9f2c3fbb9bee</RequestId>
</ResponseMetadata>
</DescribeRulesResponse>"""
DESCRIBE_TARGET_GROUPS_TEMPLATE = """<DescribeTargetGroupsResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<DescribeTargetGroupsResult>
@ -505,7 +686,7 @@ DESCRIBE_LISTENERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http://el
{% if listener.certificate %}
<Certificates>
<member>
<CertificateArn>{{ listener.certificate }} </CertificateArn>
<CertificateArn>{{ listener.certificate }}</CertificateArn>
</member>
</Certificates>
{% endif %}
@ -544,6 +725,43 @@ CONFIGURE_HEALTH_CHECK_TEMPLATE = """<ConfigureHealthCheckResponse xmlns="http:/
</ResponseMetadata>
</ConfigureHealthCheckResponse>"""
MODIFY_RULE_TEMPLATE = """<ModifyRuleResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<ModifyRuleResult>
<Rules>
{% for rule in rules %}
<member>
<IsDefault>{{ "true" if rule.is_default else "false" }}</IsDefault>
<Conditions>
{% for condition in rule.conditions %}
<member>
<Field>{{ condition["field"] }}</Field>
<Values>
{% for value in condition["values"] %}
<member>{{ value }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</Conditions>
<Priority>{{ rule.priority }}</Priority>
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member>
{% endfor %}
</Actions>
<RuleArn>{{ rule.arn }}</RuleArn>
</member>
{% endfor %}
</Rules>
</ModifyRuleResult>
<ResponseMetadata>
<RequestId>c5478c83-f397-11e5-bb98-57195a6eb84a</RequestId>
</ResponseMetadata>
</ModifyRuleResponse>"""
MODIFY_TARGET_GROUP_ATTRIBUTES_TEMPLATE = """<ModifyTargetGroupAttributesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<ModifyTargetGroupAttributesResult>
<Attributes>
@ -703,3 +921,40 @@ DESCRIBE_TARGET_HEALTH_TEMPLATE = """<DescribeTargetHealthResponse xmlns="http:/
<RequestId>c534f810-f389-11e5-9192-3fff33344cfa</RequestId>
</ResponseMetadata>
</DescribeTargetHealthResponse>"""
SET_RULE_PRIORITIES_TEMPLATE = """<SetRulePrioritiesResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2015-12-01/">
<SetRulePrioritiesResult>
<Rules>
{% for rule in rules %}
<member>
<IsDefault>{{ "true" if rule.is_default else "false" }}</IsDefault>
<Conditions>
{% for condition in rule.conditions %}
<member>
<Field>{{ condition["field"] }}</Field>
<Values>
{% for value in condition["values"] %}
<member>{{ value }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</Conditions>
<Priority>{{ rule.priority }}</Priority>
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member>
{% endfor %}
</Actions>
<RuleArn>{{ rule.arn }}</RuleArn>
</member>
{% endfor %}
</Rules>
</SetRulePrioritiesResult>
<ResponseMetadata>
<RequestId>4d7a8036-f3a7-11e5-9c02-8fd20490d5a6</RequestId>
</ResponseMetadata>
</SetRulePrioritiesResponse>"""

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,13 @@
from __future__ import unicode_literals
import base64
from datetime import datetime
import json
import pytz
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .aws_managed_policies import aws_managed_policies_data
from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException
from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id
@ -72,14 +74,32 @@ class ManagedPolicy(Policy):
is_attachable = True
def attach_to_role(self, role):
def attach_to(self, obj):
self.attachment_count += 1
role.managed_policies[self.name] = self
obj.managed_policies[self.name] = self
def detach_from(self, obj):
self.attachment_count -= 1
del obj.managed_policies[self.name]
class AWSManagedPolicy(ManagedPolicy):
"""AWS-managed policy."""
@classmethod
def from_data(cls, name, data):
return cls(name,
default_version_id=data.get('DefaultVersionId'),
path=data.get('Path'),
document=data.get('Document'))
# AWS defines some of its own managed policies and we periodically
# import them via `make aws_managed_policies`
aws_managed_policies = [
AWSManagedPolicy.from_data(name, d) for name, d
in json.loads(aws_managed_policies_data).items()]
class InlinePolicy(Policy):
"""TODO: is this needed?"""
@ -120,6 +140,13 @@ class Role(BaseModel):
def put_policy(self, policy_name, policy_json):
self.policies[policy_name] = policy_json
def delete_policy(self, policy_name):
try:
del self.policies[policy_name]
except KeyError:
raise IAMNotFoundException(
"The role policy with name {0} cannot be found.".format(policy_name))
@property
def physical_resource_id(self):
return self.id
@ -214,6 +241,7 @@ class Group(BaseModel):
)
self.users = []
self.managed_policies = {}
self.policies = {}
def get_cfn_attribute(self, attribute_name):
@ -254,6 +282,7 @@ class User(BaseModel):
self.created = datetime.utcnow()
self.mfa_devices = {}
self.policies = {}
self.managed_policies = {}
self.access_keys = []
self.password = None
self.password_reset_required = False
@ -368,115 +397,6 @@ class User(BaseModel):
)
# predefine AWS managed policies
aws_managed_policies = [
AWSManagedPolicy(
'AmazonElasticMapReduceRole',
default_version_id='v6',
path='/service-role/',
document={
"Version": "2012-10-17",
"Statement": [{
"Effect": "Allow",
"Resource": "*",
"Action": [
"ec2:AuthorizeSecurityGroupEgress",
"ec2:AuthorizeSecurityGroupIngress",
"ec2:CancelSpotInstanceRequests",
"ec2:CreateNetworkInterface",
"ec2:CreateSecurityGroup",
"ec2:CreateTags",
"ec2:DeleteNetworkInterface",
"ec2:DeleteSecurityGroup",
"ec2:DeleteTags",
"ec2:DescribeAvailabilityZones",
"ec2:DescribeAccountAttributes",
"ec2:DescribeDhcpOptions",
"ec2:DescribeInstanceStatus",
"ec2:DescribeInstances",
"ec2:DescribeKeyPairs",
"ec2:DescribeNetworkAcls",
"ec2:DescribeNetworkInterfaces",
"ec2:DescribePrefixLists",
"ec2:DescribeRouteTables",
"ec2:DescribeSecurityGroups",
"ec2:DescribeSpotInstanceRequests",
"ec2:DescribeSpotPriceHistory",
"ec2:DescribeSubnets",
"ec2:DescribeVpcAttribute",
"ec2:DescribeVpcEndpoints",
"ec2:DescribeVpcEndpointServices",
"ec2:DescribeVpcs",
"ec2:DetachNetworkInterface",
"ec2:ModifyImageAttribute",
"ec2:ModifyInstanceAttribute",
"ec2:RequestSpotInstances",
"ec2:RevokeSecurityGroupEgress",
"ec2:RunInstances",
"ec2:TerminateInstances",
"ec2:DeleteVolume",
"ec2:DescribeVolumeStatus",
"ec2:DescribeVolumes",
"ec2:DetachVolume",
"iam:GetRole",
"iam:GetRolePolicy",
"iam:ListInstanceProfiles",
"iam:ListRolePolicies",
"iam:PassRole",
"s3:CreateBucket",
"s3:Get*",
"s3:List*",
"sdb:BatchPutAttributes",
"sdb:Select",
"sqs:CreateQueue",
"sqs:Delete*",
"sqs:GetQueue*",
"sqs:PurgeQueue",
"sqs:ReceiveMessage"
]
}]
}
),
AWSManagedPolicy(
'AmazonElasticMapReduceforEC2Role',
default_version_id='v2',
path='/service-role/',
document={
"Version": "2012-10-17",
"Statement": [{
"Effect": "Allow",
"Resource": "*",
"Action": [
"cloudwatch:*",
"dynamodb:*",
"ec2:Describe*",
"elasticmapreduce:Describe*",
"elasticmapreduce:ListBootstrapActions",
"elasticmapreduce:ListClusters",
"elasticmapreduce:ListInstanceGroups",
"elasticmapreduce:ListInstances",
"elasticmapreduce:ListSteps",
"kinesis:CreateStream",
"kinesis:DeleteStream",
"kinesis:DescribeStream",
"kinesis:GetRecords",
"kinesis:GetShardIterator",
"kinesis:MergeShards",
"kinesis:PutRecord",
"kinesis:SplitShard",
"rds:Describe*",
"s3:*",
"sdb:*",
"sns:*",
"sqs:*"
]
}]
}
)
]
# TODO: add more predefined AWS managed policies
class IAMBackend(BaseBackend):
def __init__(self):
@ -487,6 +407,7 @@ class IAMBackend(BaseBackend):
self.users = {}
self.credential_report = None
self.managed_policies = self._init_managed_policies()
self.account_aliases = []
super(IAMBackend, self).__init__()
def _init_managed_policies(self):
@ -495,7 +416,47 @@ class IAMBackend(BaseBackend):
def attach_role_policy(self, policy_arn, role_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
policy = arns[policy_arn]
policy.attach_to_role(self.get_role(role_name))
policy.attach_to(self.get_role(role_name))
def detach_role_policy(self, policy_arn, role_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
policy.detach_from(self.get_role(role_name))
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
def attach_group_policy(self, policy_arn, group_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.attach_to(self.get_group(group_name))
def detach_group_policy(self, policy_arn, group_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.detach_from(self.get_group(group_name))
def attach_user_policy(self, policy_arn, user_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.attach_to(self.get_user(user_name))
def detach_user_policy(self, policy_arn, user_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
try:
policy = arns[policy_arn]
except KeyError:
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.detach_from(self.get_user(user_name))
def create_policy(self, description, path, policy_document, policy_name):
policy = ManagedPolicy(
@ -512,21 +473,15 @@ class IAMBackend(BaseBackend):
def list_attached_role_policies(self, role_name, marker=None, max_items=100, path_prefix='/'):
policies = self.get_role(role_name).managed_policies.values()
return self._filter_attached_policies(policies, marker, max_items, path_prefix)
if path_prefix:
policies = [p for p in policies if p.path.startswith(path_prefix)]
def list_attached_group_policies(self, group_name, marker=None, max_items=100, path_prefix='/'):
policies = self.get_group(group_name).managed_policies.values()
return self._filter_attached_policies(policies, marker, max_items, path_prefix)
policies = sorted(policies, key=lambda policy: policy.name)
start_idx = int(marker) if marker else 0
policies = policies[start_idx:start_idx + max_items]
if len(policies) < max_items:
marker = None
else:
marker = str(start_idx + max_items)
return policies, marker
def list_attached_user_policies(self, user_name, marker=None, max_items=100, path_prefix='/'):
policies = self.get_user(user_name).managed_policies.values()
return self._filter_attached_policies(policies, marker, max_items, path_prefix)
def list_policies(self, marker, max_items, only_attached, path_prefix, scope):
policies = self.managed_policies.values()
@ -540,6 +495,9 @@ class IAMBackend(BaseBackend):
policies = [p for p in policies if not isinstance(
p, AWSManagedPolicy)]
return self._filter_attached_policies(policies, marker, max_items, path_prefix)
def _filter_attached_policies(self, policies, marker, max_items, path_prefix):
if path_prefix:
policies = [p for p in policies if p.path.startswith(path_prefix)]
@ -584,6 +542,10 @@ class IAMBackend(BaseBackend):
role = self.get_role(role_name)
role.put_policy(policy_name, policy_json)
def delete_role_policy(self, role_name, policy_name):
role = self.get_role(role_name)
role.delete_policy(policy_name)
def get_role_policy(self, role_name, policy_name):
role = self.get_role(role_name)
for p, d in role.policies.items():
@ -897,5 +859,15 @@ class IAMBackend(BaseBackend):
report += self.users[user].to_csv()
return base64.b64encode(report.encode('ascii')).decode('ascii')
def list_account_aliases(self):
return self.account_aliases
def create_account_alias(self, alias):
# alias is force updated
self.account_aliases = [alias]
def delete_account_alias(self, alias):
self.account_aliases = []
iam_backend = IAMBackend()

View File

@ -13,6 +13,41 @@ class IamResponse(BaseResponse):
template = self.response_template(ATTACH_ROLE_POLICY_TEMPLATE)
return template.render()
def detach_role_policy(self):
role_name = self._get_param('RoleName')
policy_arn = self._get_param('PolicyArn')
iam_backend.detach_role_policy(policy_arn, role_name)
template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="DetachRolePolicyResponse")
def attach_group_policy(self):
policy_arn = self._get_param('PolicyArn')
group_name = self._get_param('GroupName')
iam_backend.attach_group_policy(policy_arn, group_name)
template = self.response_template(ATTACH_GROUP_POLICY_TEMPLATE)
return template.render()
def detach_group_policy(self):
policy_arn = self._get_param('PolicyArn')
group_name = self._get_param('GroupName')
iam_backend.detach_group_policy(policy_arn, group_name)
template = self.response_template(DETACH_GROUP_POLICY_TEMPLATE)
return template.render()
def attach_user_policy(self):
policy_arn = self._get_param('PolicyArn')
user_name = self._get_param('UserName')
iam_backend.attach_user_policy(policy_arn, user_name)
template = self.response_template(ATTACH_USER_POLICY_TEMPLATE)
return template.render()
def detach_user_policy(self):
policy_arn = self._get_param('PolicyArn')
user_name = self._get_param('UserName')
iam_backend.detach_user_policy(policy_arn, user_name)
template = self.response_template(DETACH_USER_POLICY_TEMPLATE)
return template.render()
def create_policy(self):
description = self._get_param('Description')
path = self._get_param('Path')
@ -33,6 +68,28 @@ class IamResponse(BaseResponse):
template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE)
return template.render(policies=policies, marker=marker)
def list_attached_group_policies(self):
marker = self._get_param('Marker')
max_items = self._get_int_param('MaxItems', 100)
path_prefix = self._get_param('PathPrefix', '/')
group_name = self._get_param('GroupName')
policies, marker = iam_backend.list_attached_group_policies(
group_name, marker=marker, max_items=max_items,
path_prefix=path_prefix)
template = self.response_template(LIST_ATTACHED_GROUP_POLICIES_TEMPLATE)
return template.render(policies=policies, marker=marker)
def list_attached_user_policies(self):
marker = self._get_param('Marker')
max_items = self._get_int_param('MaxItems', 100)
path_prefix = self._get_param('PathPrefix', '/')
user_name = self._get_param('UserName')
policies, marker = iam_backend.list_attached_user_policies(
user_name, marker=marker, max_items=max_items,
path_prefix=path_prefix)
template = self.response_template(LIST_ATTACHED_USER_POLICIES_TEMPLATE)
return template.render(policies=policies, marker=marker)
def list_policies(self):
marker = self._get_param('Marker')
max_items = self._get_int_param('MaxItems', 100)
@ -82,6 +139,13 @@ class IamResponse(BaseResponse):
template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="PutRolePolicyResponse")
def delete_role_policy(self):
role_name = self._get_param('RoleName')
policy_name = self._get_param('PolicyName')
iam_backend.delete_role_policy(role_name, policy_name)
template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="DeleteRolePolicyResponse")
def get_role_policy(self):
role_name = self._get_param('RoleName')
policy_name = self._get_param('PolicyName')
@ -439,6 +503,23 @@ class IamResponse(BaseResponse):
template = self.response_template(CREDENTIAL_REPORT)
return template.render(report=report)
def list_account_aliases(self):
aliases = iam_backend.list_account_aliases()
template = self.response_template(LIST_ACCOUNT_ALIASES_TEMPLATE)
return template.render(aliases=aliases)
def create_account_alias(self):
alias = self._get_param('AccountAlias')
iam_backend.create_account_alias(alias)
template = self.response_template(CREATE_ACCOUNT_ALIAS_TEMPLATE)
return template.render()
def delete_account_alias(self):
alias = self._get_param('AccountAlias')
iam_backend.delete_account_alias(alias)
template = self.response_template(DELETE_ACCOUNT_ALIAS_TEMPLATE)
return template.render()
ATTACH_ROLE_POLICY_TEMPLATE = """<AttachRolePolicyResponse>
<ResponseMetadata>
@ -446,6 +527,36 @@ ATTACH_ROLE_POLICY_TEMPLATE = """<AttachRolePolicyResponse>
</ResponseMetadata>
</AttachRolePolicyResponse>"""
DETACH_ROLE_POLICY_TEMPLATE = """<DetachRolePolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</DetachRolePolicyResponse>"""
ATTACH_USER_POLICY_TEMPLATE = """<AttachUserPolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</AttachUserPolicyResponse>"""
DETACH_USER_POLICY_TEMPLATE = """<DetachUserPolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</DetachUserPolicyResponse>"""
ATTACH_GROUP_POLICY_TEMPLATE = """<AttachGroupPolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</AttachGroupPolicyResponse>"""
DETACH_GROUP_POLICY_TEMPLATE = """<DetachGroupPolicyResponse>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</DetachGroupPolicyResponse>"""
CREATE_POLICY_TEMPLATE = """<CreatePolicyResponse>
<CreatePolicyResult>
<Policy>
@ -486,6 +597,50 @@ LIST_ATTACHED_ROLE_POLICIES_TEMPLATE = """<ListAttachedRolePoliciesResponse>
</ResponseMetadata>
</ListAttachedRolePoliciesResponse>"""
LIST_ATTACHED_GROUP_POLICIES_TEMPLATE = """<ListAttachedGroupPoliciesResponse>
<ListAttachedGroupPoliciesResult>
{% if marker is none %}
<IsTruncated>false</IsTruncated>
{% else %}
<IsTruncated>true</IsTruncated>
<Marker>{{ marker }}</Marker>
{% endif %}
<AttachedPolicies>
{% for policy in policies %}
<member>
<PolicyName>{{ policy.name }}</PolicyName>
<PolicyArn>{{ policy.arn }}</PolicyArn>
</member>
{% endfor %}
</AttachedPolicies>
</ListAttachedGroupPoliciesResult>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</ListAttachedGroupPoliciesResponse>"""
LIST_ATTACHED_USER_POLICIES_TEMPLATE = """<ListAttachedUserPoliciesResponse>
<ListAttachedUserPoliciesResult>
{% if marker is none %}
<IsTruncated>false</IsTruncated>
{% else %}
<IsTruncated>true</IsTruncated>
<Marker>{{ marker }}</Marker>
{% endif %}
<AttachedPolicies>
{% for policy in policies %}
<member>
<PolicyName>{{ policy.name }}</PolicyName>
<PolicyArn>{{ policy.arn }}</PolicyArn>
</member>
{% endfor %}
</AttachedPolicies>
</ListAttachedUserPoliciesResult>
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</ListAttachedUserPoliciesResponse>"""
LIST_POLICIES_TEMPLATE = """<ListPoliciesResponse>
<ListPoliciesResult>
{% if marker is none %}
@ -1113,3 +1268,32 @@ LIST_MFA_DEVICES_TEMPLATE = """<ListMFADevicesResponse>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</ListMFADevicesResponse>"""
LIST_ACCOUNT_ALIASES_TEMPLATE = """<ListAccountAliasesResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<ListAccountAliasesResult>
<IsTruncated>false</IsTruncated>
<AccountAliases>
{% for alias in aliases %}
<member>{{ alias }}</member>
{% endfor %}
</AccountAliases>
</ListAccountAliasesResult>
<ResponseMetadata>
<RequestId>c5a076e9-f1b0-11df-8fbe-45274EXAMPLE</RequestId>
</ResponseMetadata>
</ListAccountAliasesResponse>"""
CREATE_ACCOUNT_ALIAS_TEMPLATE = """<CreateAccountAliasResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<ResponseMetadata>
<RequestId>36b5db08-f1b0-11df-8fbe-45274EXAMPLE</RequestId>
</ResponseMetadata>
</CreateAccountAliasResponse>"""
DELETE_ACCOUNT_ALIAS_TEMPLATE = """<DeleteAccountAliasResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<ResponseMetadata>
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ResponseMetadata>
</DeleteAccountAliasResponse>"""

5
moto/logs/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from .models import logs_backends
from ..core.models import base_decorator, deprecated_base_decorator
mock_logs = base_decorator(logs_backends)
mock_logs_deprecated = deprecated_base_decorator(logs_backends)

228
moto/logs/models.py Normal file
View File

@ -0,0 +1,228 @@
from moto.core import BaseBackend
import boto.logs
from moto.core.utils import unix_time_millis
class LogEvent:
_event_id = 0
def __init__(self, ingestion_time, log_event):
self.ingestionTime = ingestion_time
self.timestamp = log_event["timestamp"]
self.message = log_event['message']
self.eventId = self.__class__._event_id
self.__class__._event_id += 1
def to_filter_dict(self):
return {
"eventId": self.eventId,
"ingestionTime": self.ingestionTime,
# "logStreamName":
"message": self.message,
"timestamp": self.timestamp
}
class LogStream:
_log_ids = 0
def __init__(self, region, log_group, name):
self.region = region
self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format(
region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name)
self.creationTime = unix_time_millis()
self.firstEventTimestamp = None
self.lastEventTimestamp = None
self.lastIngestionTime = None
self.logStreamName = name
self.storedBytes = 0
self.uploadSequenceToken = 0 # I'm guessing this is token needed for sequenceToken by put_events
self.events = []
self.__class__._log_ids += 1
def to_describe_dict(self):
return {
"arn": self.arn,
"creationTime": self.creationTime,
"firstEventTimestamp": self.firstEventTimestamp,
"lastEventTimestamp": self.lastEventTimestamp,
"lastIngestionTime": self.lastIngestionTime,
"logStreamName": self.logStreamName,
"storedBytes": self.storedBytes,
"uploadSequenceToken": str(self.uploadSequenceToken),
}
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
# TODO: ensure sequence_token
# TODO: to be thread safe this would need a lock
self.lastIngestionTime = unix_time_millis()
# TODO: make this match AWS if possible
self.storedBytes += sum([len(log_event["message"]) for log_event in log_events])
self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events]
self.uploadSequenceToken += 1
return self.uploadSequenceToken
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
def filter_func(event):
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head)
back_token = next_token
if next_token is None:
next_token = 0
events_page = events[next_token: next_token + limit]
next_token += limit
if next_token >= len(self.events):
next_token = None
return events_page, back_token, next_token
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
def filter_func(event):
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
events = []
for event in sorted(filter(filter_func, self.events), key=lambda x: x.timestamp):
event_obj = event.to_filter_dict()
event_obj['logStreamName'] = self.logStreamName
events.append(event_obj)
return events
class LogGroup:
def __init__(self, region, name, tags):
self.name = name
self.region = region
self.tags = tags
self.streams = dict() # {name: LogStream}
def create_log_stream(self, log_stream_name):
assert log_stream_name not in self.streams
self.streams[log_stream_name] = LogStream(self.region, self.name, log_stream_name)
def delete_log_stream(self, log_stream_name):
assert log_stream_name in self.streams
del self.streams[log_stream_name]
def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by):
log_streams = [stream.to_describe_dict() for name, stream in self.streams.items() if name.startswith(log_stream_name_prefix)]
def sorter(stream):
return stream.name if order_by == 'logStreamName' else stream.lastEventTimestamp
if next_token is None:
next_token = 0
log_streams = sorted(log_streams, key=sorter, reverse=descending)
new_token = next_token + limit
log_streams_page = log_streams[next_token: new_token]
if new_token >= len(log_streams):
new_token = None
return log_streams_page, new_token
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
assert log_stream_name in self.streams
stream = self.streams[log_stream_name]
return stream.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
assert log_stream_name in self.streams
stream = self.streams[log_stream_name]
return stream.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
assert not filter_pattern # TODO: impl
streams = [stream for name, stream in self.streams.items() if not log_stream_names or name in log_stream_names]
events = []
for stream in streams:
events += stream.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
if interleaved:
events = sorted(events, key=lambda event: event.timestamp)
if next_token is None:
next_token = 0
events_page = events[next_token: next_token + limit]
next_token += limit
if next_token >= len(events):
next_token = None
searched_streams = [{"logStreamName": stream.logStreamName, "searchedCompletely": True} for stream in streams]
return events_page, next_token, searched_streams
class LogsBackend(BaseBackend):
def __init__(self, region_name):
self.region_name = region_name
self.groups = dict() # { logGroupName: LogGroup}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_log_group(self, log_group_name, tags):
assert log_group_name not in self.groups
self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags)
def ensure_log_group(self, log_group_name, tags):
if log_group_name in self.groups:
return
self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags)
def delete_log_group(self, log_group_name):
assert log_group_name in self.groups
del self.groups[log_group_name]
def create_log_stream(self, log_group_name, log_stream_name):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.create_log_stream(log_stream_name)
def delete_log_stream(self, log_group_name, log_stream_name):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.delete_log_stream(log_stream_name)
def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.describe_log_streams(descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by)
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
# TODO: add support for sequence_tokens
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
logs_backends = {region.name: LogsBackend(region.name) for region in boto.logs.regions()}

114
moto/logs/responses.py Normal file
View File

@ -0,0 +1,114 @@
from moto.core.responses import BaseResponse
from .models import logs_backends
import json
# See http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/Welcome.html
class LogsResponse(BaseResponse):
@property
def logs_backend(self):
return logs_backends[self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none)
def create_log_group(self):
log_group_name = self._get_param('logGroupName')
tags = self._get_param('tags')
assert 1 <= len(log_group_name) <= 512 # TODO: assert pattern
self.logs_backend.create_log_group(log_group_name, tags)
return ''
def delete_log_group(self):
log_group_name = self._get_param('logGroupName')
self.logs_backend.delete_log_group(log_group_name)
return ''
def create_log_stream(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
self.logs_backend.create_log_stream(log_group_name, log_stream_name)
return ''
def delete_log_stream(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
self.logs_backend.delete_log_stream(log_group_name, log_stream_name)
return ''
def describe_log_streams(self):
log_group_name = self._get_param('logGroupName')
log_stream_name_prefix = self._get_param('logStreamNamePrefix')
descending = self._get_param('descending', False)
limit = self._get_param('limit', 50)
assert limit <= 50
next_token = self._get_param('nextToken')
order_by = self._get_param('orderBy', 'LogStreamName')
assert order_by in {'LogStreamName', 'LastEventTime'}
if order_by == 'LastEventTime':
assert not log_stream_name_prefix
streams, next_token = self.logs_backend.describe_log_streams(
descending, limit, log_group_name, log_stream_name_prefix,
next_token, order_by)
return json.dumps({
"logStreams": streams,
"nextToken": next_token
})
def put_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
log_events = self._get_param('logEvents')
sequence_token = self._get_param('sequenceToken')
next_sequence_token = self.logs_backend.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
return json.dumps({'nextSequenceToken': next_sequence_token})
def get_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
start_time = self._get_param('startTime')
end_time = self._get_param("endTime")
limit = self._get_param('limit', 10000)
assert limit <= 10000
next_token = self._get_param('nextToken')
start_from_head = self._get_param('startFromHead')
events, next_backward_token, next_foward_token = \
self.logs_backend.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
return json.dumps({
"events": events,
"nextBackwardToken": next_backward_token,
"nextForwardToken": next_foward_token
})
def filter_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_names = self._get_param('logStreamNames', [])
start_time = self._get_param('startTime')
# impl, see: http://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html
filter_pattern = self._get_param('filterPattern')
interleaved = self._get_param('interleaved', False)
end_time = self._get_param("endTime")
limit = self._get_param('limit', 10000)
assert limit <= 10000
next_token = self._get_param('nextToken')
events, next_token, searched_streams = self.logs_backend.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
return json.dumps({
"events": events,
"nextToken": next_token,
"searchedLogStreams": searched_streams
})

9
moto/logs/urls.py Normal file
View File

@ -0,0 +1,9 @@
from .responses import LogsResponse
url_bases = [
"https?://logs.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': LogsResponse.dispatch,
}

View File

@ -422,11 +422,11 @@ class OpsWorksBackend(BaseBackend):
stackid = kwargs['stack_id']
if stackid not in self.stacks:
raise ResourceNotFoundException(stackid)
if name in [l.name for l in self.layers.values()]:
if name in [l.name for l in self.stacks[stackid].layers]:
raise ValidationException(
'There is already a layer named "{0}" '
'for this stack'.format(name))
if shortname in [l.shortname for l in self.layers.values()]:
if shortname in [l.shortname for l in self.stacks[stackid].layers]:
raise ValidationException(
'There is already a layer with shortname "{0}" '
'for this stack'.format(shortname))

6
moto/polly/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import polly_backends
from ..core.models import base_decorator
polly_backend = polly_backends['us-east-1']
mock_polly = base_decorator(polly_backends)

114
moto/polly/models.py Normal file
View File

@ -0,0 +1,114 @@
from __future__ import unicode_literals
from xml.etree import ElementTree as ET
import datetime
import boto3
from moto.core import BaseBackend, BaseModel
from .resources import VOICE_DATA
from .utils import make_arn_for_lexicon
DEFAULT_ACCOUNT_ID = 123456789012
class Lexicon(BaseModel):
def __init__(self, name, content, region_name):
self.name = name
self.content = content
self.size = 0
self.alphabet = None
self.last_modified = None
self.language_code = None
self.lexemes_count = 0
self.arn = make_arn_for_lexicon(DEFAULT_ACCOUNT_ID, name, region_name)
self.update()
def update(self, content=None):
if content is not None:
self.content = content
# Probably a very naive approach, but it'll do for now.
try:
root = ET.fromstring(self.content)
self.size = len(self.content)
self.last_modified = int((datetime.datetime.now() -
datetime.datetime(1970, 1, 1)).total_seconds())
self.lexemes_count = len(root.findall('.'))
for key, value in root.attrib.items():
if key.endswith('alphabet'):
self.alphabet = value
elif key.endswith('lang'):
self.language_code = value
except Exception as err:
raise ValueError('Failure parsing XML: {0}'.format(err))
def to_dict(self):
return {
'Attributes': {
'Alphabet': self.alphabet,
'LanguageCode': self.language_code,
'LastModified': self.last_modified,
'LexemesCount': self.lexemes_count,
'LexiconArn': self.arn,
'Size': self.size
}
}
def __repr__(self):
return '<Lexicon {0}>'.format(self.name)
class PollyBackend(BaseBackend):
def __init__(self, region_name=None):
super(PollyBackend, self).__init__()
self.region_name = region_name
self._lexicons = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def describe_voices(self, language_code, next_token):
if language_code is None:
return VOICE_DATA
return [item for item in VOICE_DATA if item['LanguageCode'] == language_code]
def delete_lexicon(self, name):
# implement here
del self._lexicons[name]
def get_lexicon(self, name):
# Raises KeyError
return self._lexicons[name]
def list_lexicons(self, next_token):
result = []
for name, lexicon in self._lexicons.items():
lexicon_dict = lexicon.to_dict()
lexicon_dict['Name'] = name
result.append(lexicon_dict)
return result
def put_lexicon(self, name, content):
# If lexicon content is bad, it will raise ValueError
if name in self._lexicons:
# Regenerated all the stats from the XML
# but keeps the ARN
self._lexicons.update(content)
else:
lexicon = Lexicon(name, content, region_name=self.region_name)
self._lexicons[name] = lexicon
available_regions = boto3.session.Session().get_available_regions("polly")
polly_backends = {region: PollyBackend(region_name=region) for region in available_regions}

63
moto/polly/resources.py Normal file
View File

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
VOICE_DATA = [
{'Id': 'Joanna', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Joanna'},
{'Id': 'Mizuki', 'LanguageCode': 'ja-JP', 'LanguageName': 'Japanese', 'Gender': 'Female', 'Name': 'Mizuki'},
{'Id': 'Filiz', 'LanguageCode': 'tr-TR', 'LanguageName': 'Turkish', 'Gender': 'Female', 'Name': 'Filiz'},
{'Id': 'Astrid', 'LanguageCode': 'sv-SE', 'LanguageName': 'Swedish', 'Gender': 'Female', 'Name': 'Astrid'},
{'Id': 'Tatyana', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Female', 'Name': 'Tatyana'},
{'Id': 'Maxim', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Male', 'Name': 'Maxim'},
{'Id': 'Carmen', 'LanguageCode': 'ro-RO', 'LanguageName': 'Romanian', 'Gender': 'Female', 'Name': 'Carmen'},
{'Id': 'Ines', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Female', 'Name': 'Inês'},
{'Id': 'Cristiano', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Male', 'Name': 'Cristiano'},
{'Id': 'Vitoria', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Female', 'Name': 'Vitória'},
{'Id': 'Ricardo', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Male', 'Name': 'Ricardo'},
{'Id': 'Maja', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Maja'},
{'Id': 'Jan', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jan'},
{'Id': 'Ewa', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Ewa'},
{'Id': 'Ruben', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Male', 'Name': 'Ruben'},
{'Id': 'Lotte', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Female', 'Name': 'Lotte'},
{'Id': 'Liv', 'LanguageCode': 'nb-NO', 'LanguageName': 'Norwegian', 'Gender': 'Female', 'Name': 'Liv'},
{'Id': 'Giorgio', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Male', 'Name': 'Giorgio'},
{'Id': 'Carla', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Female', 'Name': 'Carla'},
{'Id': 'Karl', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Male', 'Name': 'Karl'},
{'Id': 'Dora', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Female', 'Name': 'Dóra'},
{'Id': 'Mathieu', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Male', 'Name': 'Mathieu'},
{'Id': 'Celine', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Female', 'Name': 'Céline'},
{'Id': 'Chantal', 'LanguageCode': 'fr-CA', 'LanguageName': 'Canadian French', 'Gender': 'Female', 'Name': 'Chantal'},
{'Id': 'Penelope', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Female', 'Name': 'Penélope'},
{'Id': 'Miguel', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Male', 'Name': 'Miguel'},
{'Id': 'Enrique', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Male', 'Name': 'Enrique'},
{'Id': 'Conchita', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Female', 'Name': 'Conchita'},
{'Id': 'Geraint', 'LanguageCode': 'en-GB-WLS', 'LanguageName': 'Welsh English', 'Gender': 'Male', 'Name': 'Geraint'},
{'Id': 'Salli', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Salli'},
{'Id': 'Kimberly', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kimberly'},
{'Id': 'Kendra', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kendra'},
{'Id': 'Justin', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Justin'},
{'Id': 'Joey', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Joey'},
{'Id': 'Ivy', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Ivy'},
{'Id': 'Raveena', 'LanguageCode': 'en-IN', 'LanguageName': 'Indian English', 'Gender': 'Female', 'Name': 'Raveena'},
{'Id': 'Emma', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Emma'},
{'Id': 'Brian', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Male', 'Name': 'Brian'},
{'Id': 'Amy', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Amy'},
{'Id': 'Russell', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Male', 'Name': 'Russell'},
{'Id': 'Nicole', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Female', 'Name': 'Nicole'},
{'Id': 'Vicki', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Vicki'},
{'Id': 'Marlene', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Marlene'},
{'Id': 'Hans', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Male', 'Name': 'Hans'},
{'Id': 'Naja', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Female', 'Name': 'Naja'},
{'Id': 'Mads', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Male', 'Name': 'Mads'},
{'Id': 'Gwyneth', 'LanguageCode': 'cy-GB', 'LanguageName': 'Welsh', 'Gender': 'Female', 'Name': 'Gwyneth'},
{'Id': 'Jacek', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jacek'}
]
# {...} is also shorthand set syntax
LANGUAGE_CODES = {'cy-GB', 'da-DK', 'de-DE', 'en-AU', 'en-GB', 'en-GB-WLS', 'en-IN', 'en-US', 'es-ES', 'es-US',
'fr-CA', 'fr-FR', 'is-IS', 'it-IT', 'ja-JP', 'nb-NO', 'nl-NL', 'pl-PL', 'pt-BR', 'pt-PT', 'ro-RO',
'ru-RU', 'sv-SE', 'tr-TR'}
VOICE_IDS = {'Geraint', 'Gwyneth', 'Mads', 'Naja', 'Hans', 'Marlene', 'Nicole', 'Russell', 'Amy', 'Brian', 'Emma',
'Raveena', 'Ivy', 'Joanna', 'Joey', 'Justin', 'Kendra', 'Kimberly', 'Salli', 'Conchita', 'Enrique',
'Miguel', 'Penelope', 'Chantal', 'Celine', 'Mathieu', 'Dora', 'Karl', 'Carla', 'Giorgio', 'Mizuki',
'Liv', 'Lotte', 'Ruben', 'Ewa', 'Jacek', 'Jan', 'Maja', 'Ricardo', 'Vitoria', 'Cristiano', 'Ines',
'Carmen', 'Maxim', 'Tatyana', 'Astrid', 'Filiz'}

188
moto/polly/responses.py Normal file
View File

@ -0,0 +1,188 @@
from __future__ import unicode_literals
import json
import re
from six.moves.urllib.parse import urlsplit
from moto.core.responses import BaseResponse
from .models import polly_backends
from .resources import LANGUAGE_CODES, VOICE_IDS
LEXICON_NAME_REGEX = re.compile(r'^[0-9A-Za-z]{1,20}$')
class PollyResponse(BaseResponse):
@property
def polly_backend(self):
return polly_backends[self.region]
@property
def json(self):
if not hasattr(self, '_json'):
self._json = json.loads(self.body)
return self._json
def _error(self, code, message):
return json.dumps({'__type': code, 'message': message}), dict(status=400)
def _get_action(self):
# Amazon is now naming things /v1/api_name
url_parts = urlsplit(self.uri).path.lstrip('/').split('/')
# [0] = 'v1'
return url_parts[1]
# DescribeVoices
def voices(self):
language_code = self._get_param('LanguageCode')
next_token = self._get_param('NextToken')
if language_code is not None and language_code not in LANGUAGE_CODES:
msg = "1 validation error detected: Value '{0}' at 'languageCode' failed to satisfy constraint: " \
"Member must satisfy enum value set: [{1}]".format(language_code, ', '.join(LANGUAGE_CODES))
return msg, dict(status=400)
voices = self.polly_backend.describe_voices(language_code, next_token)
return json.dumps({'Voices': voices})
def lexicons(self):
# Dish out requests based on methods
# anything after the /v1/lexicons/
args = urlsplit(self.uri).path.lstrip('/').split('/')[2:]
if self.method == 'GET':
if len(args) == 0:
return self._get_lexicons_list()
else:
return self._get_lexicon(*args)
elif self.method == 'PUT':
return self._put_lexicons(*args)
elif self.method == 'DELETE':
return self._delete_lexicon(*args)
return self._error('InvalidAction', 'Bad route')
# PutLexicon
def _put_lexicons(self, lexicon_name):
if LEXICON_NAME_REGEX.match(lexicon_name) is None:
return self._error('InvalidParameterValue', 'Lexicon name must match [0-9A-Za-z]{1,20}')
if 'Content' not in self.json:
return self._error('MissingParameter', 'Content is missing from the body')
self.polly_backend.put_lexicon(lexicon_name, self.json['Content'])
return ''
# ListLexicons
def _get_lexicons_list(self):
next_token = self._get_param('NextToken')
result = {
'Lexicons': self.polly_backend.list_lexicons(next_token)
}
return json.dumps(result)
# GetLexicon
def _get_lexicon(self, lexicon_name):
try:
lexicon = self.polly_backend.get_lexicon(lexicon_name)
except KeyError:
return self._error('LexiconNotFoundException', 'Lexicon not found')
result = {
'Lexicon': {
'Name': lexicon_name,
'Content': lexicon.content
},
'LexiconAttributes': lexicon.to_dict()['Attributes']
}
return json.dumps(result)
# DeleteLexicon
def _delete_lexicon(self, lexicon_name):
try:
self.polly_backend.delete_lexicon(lexicon_name)
except KeyError:
return self._error('LexiconNotFoundException', 'Lexicon not found')
return ''
# SynthesizeSpeech
def speech(self):
# Sanity check params
args = {
'lexicon_names': None,
'sample_rate': 22050,
'speech_marks': None,
'text': None,
'text_type': 'text'
}
if 'LexiconNames' in self.json:
for lex in self.json['LexiconNames']:
try:
self.polly_backend.get_lexicon(lex)
except KeyError:
return self._error('LexiconNotFoundException', 'Lexicon not found')
args['lexicon_names'] = self.json['LexiconNames']
if 'OutputFormat' not in self.json:
return self._error('MissingParameter', 'Missing parameter OutputFormat')
if self.json['OutputFormat'] not in ('json', 'mp3', 'ogg_vorbis', 'pcm'):
return self._error('InvalidParameterValue', 'Not one of json, mp3, ogg_vorbis, pcm')
args['output_format'] = self.json['OutputFormat']
if 'SampleRate' in self.json:
sample_rate = int(self.json['SampleRate'])
if sample_rate not in (8000, 16000, 22050):
return self._error('InvalidSampleRateException', 'The specified sample rate is not valid.')
args['sample_rate'] = sample_rate
if 'SpeechMarkTypes' in self.json:
for value in self.json['SpeechMarkTypes']:
if value not in ('sentance', 'ssml', 'viseme', 'word'):
return self._error('InvalidParameterValue', 'Not one of sentance, ssml, viseme, word')
args['speech_marks'] = self.json['SpeechMarkTypes']
if 'Text' not in self.json:
return self._error('MissingParameter', 'Missing parameter Text')
args['text'] = self.json['Text']
if 'TextType' in self.json:
if self.json['TextType'] not in ('ssml', 'text'):
return self._error('InvalidParameterValue', 'Not one of ssml, text')
args['text_type'] = self.json['TextType']
if 'VoiceId' not in self.json:
return self._error('MissingParameter', 'Missing parameter VoiceId')
if self.json['VoiceId'] not in VOICE_IDS:
return self._error('InvalidParameterValue', 'Not one of {0}'.format(', '.join(VOICE_IDS)))
args['voice_id'] = self.json['VoiceId']
# More validation
if len(args['text']) > 3000:
return self._error('TextLengthExceededException', 'Text too long')
if args['speech_marks'] is not None and args['output_format'] != 'json':
return self._error('MarksNotSupportedForFormatException', 'OutputFormat must be json')
if args['speech_marks'] is not None and args['text_type'] == 'text':
return self._error('SsmlMarksNotSupportedForTextTypeException', 'TextType must be ssml')
content_type = 'audio/json'
if args['output_format'] == 'mp3':
content_type = 'audio/mpeg'
elif args['output_format'] == 'ogg_vorbis':
content_type = 'audio/ogg'
elif args['output_format'] == 'pcm':
content_type = 'audio/pcm'
headers = {'Content-Type': content_type}
return '\x00\x00\x00\x00\x00\x00\x00\x00', headers

13
moto/polly/urls.py Normal file
View File

@ -0,0 +1,13 @@
from __future__ import unicode_literals
from .responses import PollyResponse
url_bases = [
"https?://polly.(.+).amazonaws.com",
]
url_paths = {
'{0}/v1/voices': PollyResponse.dispatch,
'{0}/v1/lexicons/(?P<lexicon>[^/]+)': PollyResponse.dispatch,
'{0}/v1/lexicons': PollyResponse.dispatch,
'{0}/v1/speech': PollyResponse.dispatch,
}

5
moto/polly/utils.py Normal file
View File

@ -0,0 +1,5 @@
from __future__ import unicode_literals
def make_arn_for_lexicon(account_id, name, region_name):
return "arn:aws:polly:{0}:{1}:lexicon/{2}".format(region_name, account_id, name)

View File

@ -58,3 +58,36 @@ class DBParameterGroupNotFoundError(RDSClientError):
super(DBParameterGroupNotFoundError, self).__init__(
'DBParameterGroupNotFound',
'DB Parameter Group {0} not found.'.format(db_parameter_group_name))
class InvalidDBClusterStateFaultError(RDSClientError):
def __init__(self, database_identifier):
super(InvalidDBClusterStateFaultError, self).__init__(
'InvalidDBClusterStateFault',
'Invalid DB type, when trying to perform StopDBInstance on {0}e. See AWS RDS documentation on rds.stop_db_instance'.format(database_identifier))
class InvalidDBInstanceStateError(RDSClientError):
def __init__(self, database_identifier, istate):
estate = "in available state" if istate == 'stop' else "stopped, it cannot be started"
super(InvalidDBInstanceStateError, self).__init__(
'InvalidDBInstanceState',
'Instance {} is not {}.'.format(database_identifier, estate))
class SnapshotQuotaExceededError(RDSClientError):
def __init__(self):
super(SnapshotQuotaExceededError, self).__init__(
'SnapshotQuotaExceeded',
'The request cannot be processed because it would exceed the maximum number of snapshots.')
class DBSnapshotAlreadyExistsError(RDSClientError):
def __init__(self, database_snapshot_identifier):
super(DBSnapshotAlreadyExistsError, self).__init__(
'DBSnapshotAlreadyExists',
'Cannot create the snapshot because a snapshot with the identifier {} already exists.'.format(database_snapshot_identifier))

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import copy
import datetime
import os
from collections import defaultdict
import boto.rds2
@ -18,7 +19,11 @@ from .exceptions import (RDSClientError,
DBSnapshotNotFoundError,
DBSecurityGroupNotFoundError,
DBSubnetGroupNotFoundError,
DBParameterGroupNotFoundError)
DBParameterGroupNotFoundError,
InvalidDBClusterStateFaultError,
InvalidDBInstanceStateError,
SnapshotQuotaExceededError,
DBSnapshotAlreadyExistsError)
class Database(BaseModel):
@ -674,10 +679,14 @@ class RDS2Backend(BaseBackend):
self.databases[database_id] = database
return database
def create_snapshot(self, db_instance_identifier, db_snapshot_identifier, tags):
def create_snapshot(self, db_instance_identifier, db_snapshot_identifier, tags=None):
database = self.databases.get(db_instance_identifier)
if not database:
raise DBInstanceNotFoundError(db_instance_identifier)
if db_snapshot_identifier in self.snapshots:
raise DBSnapshotAlreadyExistsError(db_snapshot_identifier)
if len(self.snapshots) >= int(os.environ.get('MOTO_RDS_SNAPSHOT_LIMIT', '100')):
raise SnapshotQuotaExceededError()
snapshot = Snapshot(database, db_snapshot_identifier, tags)
self.snapshots[db_snapshot_identifier] = snapshot
return snapshot
@ -733,6 +742,27 @@ class RDS2Backend(BaseBackend):
database = self.describe_databases(db_instance_identifier)[0]
return database
def stop_database(self, db_instance_identifier, db_snapshot_identifier=None):
database = self.describe_databases(db_instance_identifier)[0]
# todo: certain rds types not allowed to be stopped at this time.
if database.is_replica or database.multi_az:
# todo: more db types not supported by stop/start instance api
raise InvalidDBClusterStateFaultError(db_instance_identifier)
if database.status != 'available':
raise InvalidDBInstanceStateError(db_instance_identifier, 'stop')
if db_snapshot_identifier:
self.create_snapshot(db_instance_identifier, db_snapshot_identifier)
database.status = 'shutdown'
return database
def start_database(self, db_instance_identifier):
database = self.describe_databases(db_instance_identifier)[0]
# todo: bunch of different error messages to be generated from this api call
if database.status != 'shutdown':
raise InvalidDBInstanceStateError(db_instance_identifier, 'start')
database.status = 'available'
return database
def find_db_from_id(self, db_id):
if self.arn_regex.match(db_id):
arn_breakdown = db_id.split(':')

View File

@ -23,6 +23,7 @@ class RDS2Response(BaseResponse):
"db_instance_identifier": self._get_param('DBInstanceIdentifier'),
"db_name": self._get_param("DBName"),
"db_parameter_group_name": self._get_param("DBParameterGroupName"),
"db_snapshot_identifier": self._get_param('DBSnapshotIdentifier'),
"db_subnet_group_name": self._get_param("DBSubnetGroupName"),
"engine": self._get_param("Engine"),
"engine_version": self._get_param("EngineVersion"),
@ -193,6 +194,19 @@ class RDS2Response(BaseResponse):
template = self.response_template(REMOVE_TAGS_FROM_RESOURCE_TEMPLATE)
return template.render()
def stop_db_instance(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier')
db_snapshot_identifier = self._get_param('DBSnapshotIdentifier')
database = self.backend.stop_database(db_instance_identifier, db_snapshot_identifier)
template = self.response_template(STOP_DATABASE_TEMPLATE)
return template.render(database=database)
def start_db_instance(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier')
database = self.backend.start_database(db_instance_identifier)
template = self.response_template(START_DATABASE_TEMPLATE)
return template.render(database=database)
def create_db_security_group(self):
group_name = self._get_param('DBSecurityGroupName')
description = self._get_param('DBSecurityGroupDescription')
@ -410,6 +424,23 @@ REBOOT_DATABASE_TEMPLATE = """<RebootDBInstanceResponse xmlns="http://rds.amazon
</ResponseMetadata>
</RebootDBInstanceResponse>"""
START_DATABASE_TEMPLATE = """<StartDBInstanceResponse xmlns="http://rds.amazonaws.com/doc/2014-10-31/">
<StartDBInstanceResult>
{{ database.to_xml() }}
</StartDBInstanceResult>
<ResponseMetadata>
<RequestId>523e3218-afc7-11c3-90f5-f90431260ab9</RequestId>
</ResponseMetadata>
</StartDBInstanceResponse>"""
STOP_DATABASE_TEMPLATE = """<StopDBInstanceResponse xmlns="http://rds.amazonaws.com/doc/2014-10-31/">
<StopDBInstanceResult>
{{ database.to_xml() }}
</StopDBInstanceResult>
<ResponseMetadata>
<RequestId>523e3218-afc7-11c3-90f5-f90431260ab8</RequestId>
</ResponseMetadata>
</StopDBInstanceResponse>"""
DELETE_DATABASE_TEMPLATE = """<DeleteDBInstanceResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/">
<DeleteDBInstanceResult>

View File

@ -71,3 +71,25 @@ class ClusterSnapshotAlreadyExistsError(RedshiftClientError):
'ClusterSnapshotAlreadyExists',
"Cannot create the snapshot because a snapshot with the "
"identifier {0} already exists".format(snapshot_identifier))
class InvalidParameterValueError(RedshiftClientError):
def __init__(self, message):
super(InvalidParameterValueError, self).__init__(
'InvalidParameterValue',
message)
class ResourceNotFoundFaultError(RedshiftClientError):
code = 404
def __init__(self, resource_type=None, resource_name=None, message=None):
if resource_type and not resource_name:
msg = "resource of type '{0}' not found.".format(resource_type)
else:
msg = "{0} ({1}) not found.".format(resource_type, resource_name)
if message:
msg = message
super(ResourceNotFoundFaultError, self).__init__(
'ResourceNotFoundFault', msg)

View File

@ -15,11 +15,51 @@ from .exceptions import (
ClusterSnapshotAlreadyExistsError,
ClusterSnapshotNotFoundError,
ClusterSubnetGroupNotFoundError,
InvalidParameterValueError,
InvalidSubnetError,
ResourceNotFoundFaultError
)
class Cluster(BaseModel):
ACCOUNT_ID = 123456789012
class TaggableResourceMixin(object):
resource_type = None
def __init__(self, region_name, tags):
self.region = region_name
self.tags = tags or []
@property
def resource_id(self):
return None
@property
def arn(self):
return "arn:aws:redshift:{region}:{account_id}:{resource_type}:{resource_id}".format(
region=self.region,
account_id=ACCOUNT_ID,
resource_type=self.resource_type,
resource_id=self.resource_id)
def create_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags
if tag_set['Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def delete_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags
if tag_set['Key'] not in tag_keys]
return self.tags
class Cluster(TaggableResourceMixin, BaseModel):
resource_type = 'cluster'
def __init__(self, redshift_backend, cluster_identifier, node_type, master_username,
master_user_password, db_name, cluster_type, cluster_security_groups,
@ -27,7 +67,8 @@ class Cluster(BaseModel):
preferred_maintenance_window, cluster_parameter_group_name,
automated_snapshot_retention_period, port, cluster_version,
allow_version_upgrade, number_of_nodes, publicly_accessible,
encrypted, region):
encrypted, region_name, tags=None):
super(Cluster, self).__init__(region_name, tags)
self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier
self.status = 'available'
@ -57,13 +98,12 @@ class Cluster(BaseModel):
else:
self.cluster_security_groups = ["Default"]
self.region = region
if availability_zone:
self.availability_zone = availability_zone
else:
# This could probably be smarter, but there doesn't appear to be a
# way to pull AZs for a region in boto
self.availability_zone = region + "a"
self.availability_zone = region_name + "a"
if cluster_type == 'single-node':
self.number_of_nodes = 1
@ -106,7 +146,7 @@ class Cluster(BaseModel):
number_of_nodes=properties.get('NumberOfNodes'),
publicly_accessible=properties.get("PubliclyAccessible"),
encrypted=properties.get("Encrypted"),
region=region_name,
region_name=region_name,
)
return cluster
@ -149,6 +189,10 @@ class Cluster(BaseModel):
if parameter_group.cluster_parameter_group_name in self.cluster_parameter_group_name
]
@property
def resource_id(self):
return self.cluster_identifier
def to_json(self):
return {
"MasterUsername": self.master_username,
@ -180,18 +224,21 @@ class Cluster(BaseModel):
"ClusterIdentifier": self.cluster_identifier,
"AllowVersionUpgrade": self.allow_version_upgrade,
"Endpoint": {
"Address": '{}.{}.redshift.amazonaws.com'.format(
self.cluster_identifier,
self.region),
"Address": self.endpoint,
"Port": self.port
},
"PendingModifiedValues": []
"PendingModifiedValues": [],
"Tags": self.tags
}
class SubnetGroup(BaseModel):
class SubnetGroup(TaggableResourceMixin, BaseModel):
def __init__(self, ec2_backend, cluster_subnet_group_name, description, subnet_ids):
resource_type = 'subnetgroup'
def __init__(self, ec2_backend, cluster_subnet_group_name, description, subnet_ids,
region_name, tags=None):
super(SubnetGroup, self).__init__(region_name, tags)
self.ec2_backend = ec2_backend
self.cluster_subnet_group_name = cluster_subnet_group_name
self.description = description
@ -208,6 +255,7 @@ class SubnetGroup(BaseModel):
cluster_subnet_group_name=resource_name,
description=properties.get("Description"),
subnet_ids=properties.get("SubnetIds", []),
region_name=region_name
)
return subnet_group
@ -219,6 +267,10 @@ class SubnetGroup(BaseModel):
def vpc_id(self):
return self.subnets[0].vpc_id
@property
def resource_id(self):
return self.cluster_subnet_group_name
def to_json(self):
return {
"VpcId": self.vpc_id,
@ -232,27 +284,39 @@ class SubnetGroup(BaseModel):
"Name": subnet.availability_zone
},
} for subnet in self.subnets],
"Tags": self.tags
}
class SecurityGroup(BaseModel):
class SecurityGroup(TaggableResourceMixin, BaseModel):
def __init__(self, cluster_security_group_name, description):
resource_type = 'securitygroup'
def __init__(self, cluster_security_group_name, description, region_name, tags=None):
super(SecurityGroup, self).__init__(region_name, tags)
self.cluster_security_group_name = cluster_security_group_name
self.description = description
@property
def resource_id(self):
return self.cluster_security_group_name
def to_json(self):
return {
"EC2SecurityGroups": [],
"IPRanges": [],
"Description": self.description,
"ClusterSecurityGroupName": self.cluster_security_group_name,
"Tags": self.tags
}
class ParameterGroup(BaseModel):
class ParameterGroup(TaggableResourceMixin, BaseModel):
def __init__(self, cluster_parameter_group_name, group_family, description):
resource_type = 'parametergroup'
def __init__(self, cluster_parameter_group_name, group_family, description, region_name, tags=None):
super(ParameterGroup, self).__init__(region_name, tags)
self.cluster_parameter_group_name = cluster_parameter_group_name
self.group_family = group_family
self.description = description
@ -266,34 +330,41 @@ class ParameterGroup(BaseModel):
cluster_parameter_group_name=resource_name,
description=properties.get("Description"),
group_family=properties.get("ParameterGroupFamily"),
region_name=region_name
)
return parameter_group
@property
def resource_id(self):
return self.cluster_parameter_group_name
def to_json(self):
return {
"ParameterGroupFamily": self.group_family,
"Description": self.description,
"ParameterGroupName": self.cluster_parameter_group_name,
"Tags": self.tags
}
class Snapshot(BaseModel):
class Snapshot(TaggableResourceMixin, BaseModel):
def __init__(self, cluster, snapshot_identifier, tags=None):
resource_type = 'snapshot'
def __init__(self, cluster, snapshot_identifier, region_name, tags=None):
super(Snapshot, self).__init__(region_name, tags)
self.cluster = copy.copy(cluster)
self.snapshot_identifier = snapshot_identifier
self.snapshot_type = 'manual'
self.status = 'available'
self.tags = tags or []
self.create_time = iso_8601_datetime_with_milliseconds(
datetime.datetime.now())
@property
def arn(self):
return "arn:aws:redshift:{0}:1234567890:snapshot:{1}/{2}".format(
self.cluster.region,
self.cluster.cluster_identifier,
self.snapshot_identifier)
def resource_id(self):
return "{cluster_id}/{snapshot_id}".format(
cluster_id=self.cluster.cluster_identifier,
snapshot_id=self.snapshot_identifier)
def to_json(self):
return {
@ -315,26 +386,36 @@ class Snapshot(BaseModel):
class RedshiftBackend(BaseBackend):
def __init__(self, ec2_backend):
def __init__(self, ec2_backend, region_name):
self.region = region_name
self.clusters = {}
self.subnet_groups = {}
self.security_groups = {
"Default": SecurityGroup("Default", "Default Redshift Security Group")
"Default": SecurityGroup("Default", "Default Redshift Security Group", self.region)
}
self.parameter_groups = {
"default.redshift-1.0": ParameterGroup(
"default.redshift-1.0",
"redshift-1.0",
"Default Redshift parameter group",
self.region
)
}
self.ec2_backend = ec2_backend
self.snapshots = OrderedDict()
self.RESOURCE_TYPE_MAP = {
'cluster': self.clusters,
'parametergroup': self.parameter_groups,
'securitygroup': self.security_groups,
'snapshot': self.snapshots,
'subnetgroup': self.subnet_groups
}
def reset(self):
ec2_backend = self.ec2_backend
region_name = self.region
self.__dict__ = {}
self.__init__(ec2_backend)
self.__init__(ec2_backend, region_name)
def create_cluster(self, **cluster_kwargs):
cluster_identifier = cluster_kwargs['cluster_identifier']
@ -373,9 +454,10 @@ class RedshiftBackend(BaseBackend):
return self.clusters.pop(cluster_identifier)
raise ClusterNotFoundError(cluster_identifier)
def create_cluster_subnet_group(self, cluster_subnet_group_name, description, subnet_ids):
def create_cluster_subnet_group(self, cluster_subnet_group_name, description, subnet_ids,
region_name, tags=None):
subnet_group = SubnetGroup(
self.ec2_backend, cluster_subnet_group_name, description, subnet_ids)
self.ec2_backend, cluster_subnet_group_name, description, subnet_ids, region_name, tags)
self.subnet_groups[cluster_subnet_group_name] = subnet_group
return subnet_group
@ -393,9 +475,9 @@ class RedshiftBackend(BaseBackend):
return self.subnet_groups.pop(subnet_identifier)
raise ClusterSubnetGroupNotFoundError(subnet_identifier)
def create_cluster_security_group(self, cluster_security_group_name, description):
def create_cluster_security_group(self, cluster_security_group_name, description, region_name, tags=None):
security_group = SecurityGroup(
cluster_security_group_name, description)
cluster_security_group_name, description, region_name, tags)
self.security_groups[cluster_security_group_name] = security_group
return security_group
@ -414,9 +496,9 @@ class RedshiftBackend(BaseBackend):
raise ClusterSecurityGroupNotFoundError(security_group_identifier)
def create_cluster_parameter_group(self, cluster_parameter_group_name,
group_family, description):
group_family, description, region_name, tags=None):
parameter_group = ParameterGroup(
cluster_parameter_group_name, group_family, description)
cluster_parameter_group_name, group_family, description, region_name, tags)
self.parameter_groups[cluster_parameter_group_name] = parameter_group
return parameter_group
@ -435,17 +517,17 @@ class RedshiftBackend(BaseBackend):
return self.parameter_groups.pop(parameter_group_name)
raise ClusterParameterGroupNotFoundError(parameter_group_name)
def create_snapshot(self, cluster_identifier, snapshot_identifier, tags):
def create_cluster_snapshot(self, cluster_identifier, snapshot_identifier, region_name, tags):
cluster = self.clusters.get(cluster_identifier)
if not cluster:
raise ClusterNotFoundError(cluster_identifier)
if self.snapshots.get(snapshot_identifier) is not None:
raise ClusterSnapshotAlreadyExistsError(snapshot_identifier)
snapshot = Snapshot(cluster, snapshot_identifier, tags)
snapshot = Snapshot(cluster, snapshot_identifier, region_name, tags)
self.snapshots[snapshot_identifier] = snapshot
return snapshot
def describe_snapshots(self, cluster_identifier, snapshot_identifier):
def describe_cluster_snapshots(self, cluster_identifier=None, snapshot_identifier=None):
if cluster_identifier:
for snapshot in self.snapshots.values():
if snapshot.cluster.cluster_identifier == cluster_identifier:
@ -459,7 +541,7 @@ class RedshiftBackend(BaseBackend):
return self.snapshots.values()
def delete_snapshot(self, snapshot_identifier):
def delete_cluster_snapshot(self, snapshot_identifier):
if snapshot_identifier not in self.snapshots:
raise ClusterSnapshotNotFoundError(snapshot_identifier)
@ -467,23 +549,105 @@ class RedshiftBackend(BaseBackend):
deleted_snapshot.status = 'deleted'
return deleted_snapshot
def describe_tags_for_resource_type(self, resource_type):
def restore_from_cluster_snapshot(self, **kwargs):
snapshot_identifier = kwargs.pop('snapshot_identifier')
snapshot = self.describe_cluster_snapshots(snapshot_identifier=snapshot_identifier)[0]
create_kwargs = {
"node_type": snapshot.cluster.node_type,
"master_username": snapshot.cluster.master_username,
"master_user_password": snapshot.cluster.master_user_password,
"db_name": snapshot.cluster.db_name,
"cluster_type": 'multi-node' if snapshot.cluster.number_of_nodes > 1 else 'single-node',
"availability_zone": snapshot.cluster.availability_zone,
"port": snapshot.cluster.port,
"cluster_version": snapshot.cluster.cluster_version,
"number_of_nodes": snapshot.cluster.number_of_nodes,
"encrypted": snapshot.cluster.encrypted,
"tags": snapshot.cluster.tags
}
create_kwargs.update(kwargs)
return self.create_cluster(**create_kwargs)
def _get_resource_from_arn(self, arn):
try:
arn_breakdown = arn.split(':')
resource_type = arn_breakdown[5]
if resource_type == 'snapshot':
resource_id = arn_breakdown[6].split('/')[1]
else:
resource_id = arn_breakdown[6]
except IndexError:
resource_type = resource_id = arn
resources = self.RESOURCE_TYPE_MAP.get(resource_type)
if resources is None:
message = (
"Tagging is not supported for this type of resource: '{0}' "
"(the ARN is potentially malformed, please check the ARN "
"documentation for more information)".format(resource_type))
raise ResourceNotFoundFaultError(message=message)
try:
resource = resources[resource_id]
except KeyError:
raise ResourceNotFoundFaultError(resource_type, resource_id)
else:
return resource
@staticmethod
def _describe_tags_for_resources(resources):
tagged_resources = []
if resource_type == 'Snapshot':
for snapshot in self.snapshots.values():
for tag in snapshot.tags:
data = {
'ResourceName': snapshot.arn,
'ResourceType': 'snapshot',
'Tag': {
'Key': tag['Key'],
'Value': tag['Value']
}
for resource in resources:
for tag in resource.tags:
data = {
'ResourceName': resource.arn,
'ResourceType': resource.resource_type,
'Tag': {
'Key': tag['Key'],
'Value': tag['Value']
}
tagged_resources.append(data)
}
tagged_resources.append(data)
return tagged_resources
def _describe_tags_for_resource_type(self, resource_type):
resources = self.RESOURCE_TYPE_MAP.get(resource_type)
if not resources:
raise ResourceNotFoundFaultError(resource_type=resource_type)
return self._describe_tags_for_resources(resources.values())
def _describe_tags_for_resource_name(self, resource_name):
resource = self._get_resource_from_arn(resource_name)
return self._describe_tags_for_resources([resource])
def create_tags(self, resource_name, tags):
resource = self._get_resource_from_arn(resource_name)
resource.create_tags(tags)
def describe_tags(self, resource_name, resource_type):
if resource_name and resource_type:
raise InvalidParameterValueError(
"You cannot filter a list of resources using an Amazon "
"Resource Name (ARN) and a resource type together in the "
"same request. Retry the request using either an ARN or "
"a resource type, but not both.")
if resource_type:
return self._describe_tags_for_resource_type(resource_type.lower())
if resource_name:
return self._describe_tags_for_resource_name(resource_name)
# If name and type are not specified, return all tagged resources.
# TODO: Implement aws marker pagination
tagged_resources = []
for resource_type in self.RESOURCE_TYPE_MAP:
try:
tagged_resources += self._describe_tags_for_resource_type(resource_type)
except ResourceNotFoundFaultError:
pass
return tagged_resources
def delete_tags(self, resource_name, tag_keys):
resource = self._get_resource_from_arn(resource_name)
resource.delete_tags(tag_keys)
redshift_backends = {}
for region in boto.redshift.regions():
redshift_backends[region.name] = RedshiftBackend(ec2_backends[region.name])
redshift_backends[region.name] = RedshiftBackend(ec2_backends[region.name], region.name)

View File

@ -57,6 +57,33 @@ class RedshiftResponse(BaseResponse):
count += 1
return unpacked_list
def unpack_list_params(self, label):
unpacked_list = list()
count = 1
while self._get_param('{0}.{1}'.format(label, count)):
unpacked_list.append(self._get_param(
'{0}.{1}'.format(label, count)))
count += 1
return unpacked_list
def _get_cluster_security_groups(self):
cluster_security_groups = self._get_multi_param('ClusterSecurityGroups.member')
if not cluster_security_groups:
cluster_security_groups = self._get_multi_param('ClusterSecurityGroups.ClusterSecurityGroupName')
return cluster_security_groups
def _get_vpc_security_group_ids(self):
vpc_security_group_ids = self._get_multi_param('VpcSecurityGroupIds.member')
if not vpc_security_group_ids:
vpc_security_group_ids = self._get_multi_param('VpcSecurityGroupIds.VpcSecurityGroupId')
return vpc_security_group_ids
def _get_subnet_ids(self):
subnet_ids = self._get_multi_param('SubnetIds.member')
if not subnet_ids:
subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier')
return subnet_ids
def create_cluster(self):
cluster_kwargs = {
"cluster_identifier": self._get_param('ClusterIdentifier'),
@ -65,8 +92,8 @@ class RedshiftResponse(BaseResponse):
"master_user_password": self._get_param('MasterUserPassword'),
"db_name": self._get_param('DBName'),
"cluster_type": self._get_param('ClusterType'),
"cluster_security_groups": self._get_multi_param('ClusterSecurityGroups.member'),
"vpc_security_group_ids": self._get_multi_param('VpcSecurityGroupIds.member'),
"cluster_security_groups": self._get_cluster_security_groups(),
"vpc_security_group_ids": self._get_vpc_security_group_ids(),
"cluster_subnet_group_name": self._get_param('ClusterSubnetGroupName'),
"availability_zone": self._get_param('AvailabilityZone'),
"preferred_maintenance_window": self._get_param('PreferredMaintenanceWindow'),
@ -78,7 +105,8 @@ class RedshiftResponse(BaseResponse):
"number_of_nodes": self._get_int_param('NumberOfNodes'),
"publicly_accessible": self._get_param("PubliclyAccessible"),
"encrypted": self._get_param("Encrypted"),
"region": self.region,
"region_name": self.region,
"tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
}
cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json()
cluster['ClusterStatus'] = 'creating'
@ -94,23 +122,8 @@ class RedshiftResponse(BaseResponse):
})
def restore_from_cluster_snapshot(self):
snapshot_identifier = self._get_param('SnapshotIdentifier')
snapshots = self.redshift_backend.describe_snapshots(
None,
snapshot_identifier)
snapshot = snapshots[0]
kwargs_from_snapshot = {
"node_type": snapshot.cluster.node_type,
"master_username": snapshot.cluster.master_username,
"master_user_password": snapshot.cluster.master_user_password,
"db_name": snapshot.cluster.db_name,
"cluster_type": 'multi-node' if snapshot.cluster.number_of_nodes > 1 else 'single-node',
"availability_zone": snapshot.cluster.availability_zone,
"port": snapshot.cluster.port,
"cluster_version": snapshot.cluster.cluster_version,
"number_of_nodes": snapshot.cluster.number_of_nodes,
}
kwargs_from_request = {
restore_kwargs = {
"snapshot_identifier": self._get_param('SnapshotIdentifier'),
"cluster_identifier": self._get_param('ClusterIdentifier'),
"port": self._get_int_param('Port'),
"availability_zone": self._get_param('AvailabilityZone'),
@ -121,20 +134,15 @@ class RedshiftResponse(BaseResponse):
"publicly_accessible": self._get_param("PubliclyAccessible"),
"cluster_parameter_group_name": self._get_param(
'ClusterParameterGroupName'),
"cluster_security_groups": self._get_multi_param(
'ClusterSecurityGroups.member'),
"vpc_security_group_ids": self._get_multi_param(
'VpcSecurityGroupIds.member'),
"cluster_security_groups": self._get_cluster_security_groups(),
"vpc_security_group_ids": self._get_vpc_security_group_ids(),
"preferred_maintenance_window": self._get_param(
'PreferredMaintenanceWindow'),
"automated_snapshot_retention_period": self._get_int_param(
'AutomatedSnapshotRetentionPeriod'),
"region": self.region,
"encrypted": False,
"region_name": self.region,
}
kwargs_from_snapshot.update(kwargs_from_request)
cluster_kwargs = kwargs_from_snapshot
cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json()
cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json()
cluster['ClusterStatus'] = 'creating'
return self.get_response({
"RestoreFromClusterSnapshotResponse": {
@ -169,8 +177,8 @@ class RedshiftResponse(BaseResponse):
"node_type": self._get_param('NodeType'),
"master_user_password": self._get_param('MasterUserPassword'),
"cluster_type": self._get_param('ClusterType'),
"cluster_security_groups": self._get_multi_param('ClusterSecurityGroups.member'),
"vpc_security_group_ids": self._get_multi_param('VpcSecurityGroupIds.member'),
"cluster_security_groups": self._get_cluster_security_groups(),
"vpc_security_group_ids": self._get_vpc_security_group_ids(),
"cluster_subnet_group_name": self._get_param('ClusterSubnetGroupName'),
"preferred_maintenance_window": self._get_param('PreferredMaintenanceWindow'),
"cluster_parameter_group_name": self._get_param('ClusterParameterGroupName'),
@ -181,12 +189,6 @@ class RedshiftResponse(BaseResponse):
"publicly_accessible": self._get_param("PubliclyAccessible"),
"encrypted": self._get_param("Encrypted"),
}
# There's a bug in boto3 where the security group ids are not passed
# according to the AWS documentation
if not request_kwargs['vpc_security_group_ids']:
request_kwargs['vpc_security_group_ids'] = self._get_multi_param(
'VpcSecurityGroupIds.VpcSecurityGroupId')
cluster_kwargs = {}
# We only want parameters that were actually passed in, otherwise
# we'll stomp all over our cluster metadata with None values.
@ -225,16 +227,15 @@ class RedshiftResponse(BaseResponse):
def create_cluster_subnet_group(self):
cluster_subnet_group_name = self._get_param('ClusterSubnetGroupName')
description = self._get_param('Description')
subnet_ids = self._get_multi_param('SubnetIds.member')
# There's a bug in boto3 where the subnet ids are not passed
# according to the AWS documentation
if not subnet_ids:
subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier')
subnet_ids = self._get_subnet_ids()
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
subnet_group = self.redshift_backend.create_cluster_subnet_group(
cluster_subnet_group_name=cluster_subnet_group_name,
description=description,
subnet_ids=subnet_ids,
region_name=self.region,
tags=tags
)
return self.get_response({
@ -280,10 +281,13 @@ class RedshiftResponse(BaseResponse):
cluster_security_group_name = self._get_param(
'ClusterSecurityGroupName')
description = self._get_param('Description')
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
security_group = self.redshift_backend.create_cluster_security_group(
cluster_security_group_name=cluster_security_group_name,
description=description,
region_name=self.region,
tags=tags
)
return self.get_response({
@ -331,11 +335,14 @@ class RedshiftResponse(BaseResponse):
cluster_parameter_group_name = self._get_param('ParameterGroupName')
group_family = self._get_param('ParameterGroupFamily')
description = self._get_param('Description')
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
parameter_group = self.redshift_backend.create_cluster_parameter_group(
cluster_parameter_group_name,
group_family,
description,
self.region,
tags
)
return self.get_response({
@ -381,11 +388,12 @@ class RedshiftResponse(BaseResponse):
def create_cluster_snapshot(self):
cluster_identifier = self._get_param('ClusterIdentifier')
snapshot_identifier = self._get_param('SnapshotIdentifier')
tags = self.unpack_complex_list_params(
'Tags.Tag', ('Key', 'Value'))
snapshot = self.redshift_backend.create_snapshot(cluster_identifier,
snapshot_identifier,
tags)
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
snapshot = self.redshift_backend.create_cluster_snapshot(cluster_identifier,
snapshot_identifier,
self.region,
tags)
return self.get_response({
'CreateClusterSnapshotResponse': {
"CreateClusterSnapshotResult": {
@ -399,9 +407,9 @@ class RedshiftResponse(BaseResponse):
def describe_cluster_snapshots(self):
cluster_identifier = self._get_param('ClusterIdentifier')
snapshot_identifier = self._get_param('DBSnapshotIdentifier')
snapshots = self.redshift_backend.describe_snapshots(cluster_identifier,
snapshot_identifier)
snapshot_identifier = self._get_param('SnapshotIdentifier')
snapshots = self.redshift_backend.describe_cluster_snapshots(cluster_identifier,
snapshot_identifier)
return self.get_response({
"DescribeClusterSnapshotsResponse": {
"DescribeClusterSnapshotsResult": {
@ -415,7 +423,7 @@ class RedshiftResponse(BaseResponse):
def delete_cluster_snapshot(self):
snapshot_identifier = self._get_param('SnapshotIdentifier')
snapshot = self.redshift_backend.delete_snapshot(snapshot_identifier)
snapshot = self.redshift_backend.delete_cluster_snapshot(snapshot_identifier)
return self.get_response({
"DeleteClusterSnapshotResponse": {
@ -428,13 +436,26 @@ class RedshiftResponse(BaseResponse):
}
})
def create_tags(self):
resource_name = self._get_param('ResourceName')
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
self.redshift_backend.create_tags(resource_name, tags)
return self.get_response({
"CreateTagsResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a",
}
}
})
def describe_tags(self):
resource_name = self._get_param('ResourceName')
resource_type = self._get_param('ResourceType')
if resource_type != 'Snapshot':
raise NotImplementedError(
"The describe_tags action has not been fully implemented.")
tagged_resources = \
self.redshift_backend.describe_tags_for_resource_type(resource_type)
tagged_resources = self.redshift_backend.describe_tags(resource_name,
resource_type)
return self.get_response({
"DescribeTagsResponse": {
"DescribeTagsResult": {
@ -445,3 +466,17 @@ class RedshiftResponse(BaseResponse):
}
}
})
def delete_tags(self):
resource_name = self._get_param('ResourceName')
tag_keys = self.unpack_list_params('TagKeys.TagKey')
self.redshift_backend.delete_tags(resource_name, tag_keys)
return self.get_response({
"DeleteTagsResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a",
}
}
})

View File

@ -2,11 +2,20 @@ from __future__ import unicode_literals
from collections import defaultdict
import string
import random
import uuid
from jinja2 import Template
from moto.core import BaseBackend, BaseModel
from moto.core.utils import get_random_hex
ROUTE53_ID_CHOICE = string.ascii_uppercase + string.digits
def create_route53_zone_id():
# New ID's look like this Z1RWWTK7Y8UDDQ
return ''.join([random.choice(ROUTE53_ID_CHOICE) for _ in range(0, 15)])
class HealthCheck(BaseModel):
@ -247,7 +256,7 @@ class Route53Backend(BaseBackend):
self.resource_tags = defaultdict(dict)
def create_hosted_zone(self, name, private_zone, comment=None):
new_id = get_random_hex()
new_id = create_route53_zone_id()
new_zone = FakeZone(
name, new_id, private_zone=private_zone, comment=comment)
self.zones[new_id] = new_zone

View File

@ -91,3 +91,23 @@ class EntityTooSmall(S3ClientError):
"EntityTooSmall",
"Your proposed upload is smaller than the minimum allowed object size.",
*args, **kwargs)
class InvalidRequest(S3ClientError):
code = 400
def __init__(self, method, *args, **kwargs):
super(InvalidRequest, self).__init__(
"InvalidRequest",
"Found unsupported HTTP method in CORS config. Unsupported method is {}".format(method),
*args, **kwargs)
class MalformedXML(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
super(MalformedXML, self).__init__(
"MalformedXML",
"The XML you provided was not well-formed or did not validate against our published schema",
*args, **kwargs)

View File

@ -201,10 +201,18 @@ class FakeGrantee(BaseModel):
self.uri = uri
self.display_name = display_name
def __eq__(self, other):
if not isinstance(other, FakeGrantee):
return False
return self.id == other.id and self.uri == other.uri and self.display_name == other.display_name
@property
def type(self):
return 'Group' if self.uri else 'CanonicalUser'
def __repr__(self):
return "FakeGrantee(display_name: '{}', id: '{}', uri: '{}')".format(self.display_name, self.id, self.uri)
ALL_USERS_GRANTEE = FakeGrantee(
uri='http://acs.amazonaws.com/groups/global/AllUsers')
@ -226,12 +234,28 @@ class FakeGrant(BaseModel):
self.grantees = grantees
self.permissions = permissions
def __repr__(self):
return "FakeGrant(grantees: {}, permissions: {})".format(self.grantees, self.permissions)
class FakeAcl(BaseModel):
def __init__(self, grants=[]):
self.grants = grants
@property
def public_read(self):
for grant in self.grants:
if ALL_USERS_GRANTEE in grant.grantees:
if PERMISSION_READ in grant.permissions:
return True
if PERMISSION_FULL_CONTROL in grant.permissions:
return True
return False
def __repr__(self):
return "FakeAcl(grants: {})".format(self.grants)
def get_canned_acl(acl):
owner_grantee = FakeGrantee(
@ -295,6 +319,26 @@ class LifecycleRule(BaseModel):
self.storage_class = storage_class
class CorsRule(BaseModel):
def __init__(self, allowed_methods, allowed_origins, allowed_headers=None, expose_headers=None,
max_age_seconds=None):
# Python 2 and 3 have different string types for handling unicodes. Python 2 wants `basestring`,
# whereas Python 3 is OK with str. This causes issues with the XML parser, which returns
# unicode strings in Python 2. So, need to do this to make it work in both Python 2 and 3:
import sys
if sys.version_info >= (3, 0):
str_type = str
else:
str_type = basestring # noqa
self.allowed_methods = [allowed_methods] if isinstance(allowed_methods, str_type) else allowed_methods
self.allowed_origins = [allowed_origins] if isinstance(allowed_origins, str_type) else allowed_origins
self.allowed_headers = [allowed_headers] if isinstance(allowed_headers, str_type) else allowed_headers
self.exposed_headers = [expose_headers] if isinstance(expose_headers, str_type) else expose_headers
self.max_age_seconds = max_age_seconds
class FakeBucket(BaseModel):
def __init__(self, name, region_name):
@ -307,6 +351,8 @@ class FakeBucket(BaseModel):
self.policy = None
self.website_configuration = None
self.acl = get_canned_acl('private')
self.tags = FakeTagging()
self.cors = []
@property
def location(self):
@ -336,6 +382,61 @@ class FakeBucket(BaseModel):
def delete_lifecycle(self):
self.rules = []
def set_cors(self, rules):
from moto.s3.exceptions import InvalidRequest, MalformedXML
self.cors = []
if len(rules) > 100:
raise MalformedXML()
# Python 2 and 3 have different string types for handling unicodes. Python 2 wants `basestring`,
# whereas Python 3 is OK with str. This causes issues with the XML parser, which returns
# unicode strings in Python 2. So, need to do this to make it work in both Python 2 and 3:
import sys
if sys.version_info >= (3, 0):
str_type = str
else:
str_type = basestring # noqa
for rule in rules:
assert isinstance(rule["AllowedMethod"], list) or isinstance(rule["AllowedMethod"], str_type)
assert isinstance(rule["AllowedOrigin"], list) or isinstance(rule["AllowedOrigin"], str_type)
assert isinstance(rule.get("AllowedHeader", []), list) or isinstance(rule.get("AllowedHeader", ""),
str_type)
assert isinstance(rule.get("ExposedHeader", []), list) or isinstance(rule.get("ExposedHeader", ""),
str_type)
assert isinstance(rule.get("MaxAgeSeconds", "0"), str_type)
if isinstance(rule["AllowedMethod"], str_type):
methods = [rule["AllowedMethod"]]
else:
methods = rule["AllowedMethod"]
for method in methods:
if method not in ["GET", "PUT", "HEAD", "POST", "DELETE"]:
raise InvalidRequest(method)
self.cors.append(CorsRule(
rule["AllowedMethod"],
rule["AllowedOrigin"],
rule.get("AllowedHeader"),
rule.get("ExposedHeader"),
rule.get("MaxAgeSecond")
))
def delete_cors(self):
self.cors = []
def set_tags(self, tagging):
self.tags = tagging
def delete_tags(self):
self.tags = FakeTagging()
@property
def tagging(self):
return self.tags
def set_website_configuration(self, website_configuration):
self.website_configuration = website_configuration
@ -422,14 +523,15 @@ class S3Backend(BaseBackend):
encoding_type=None,
key_marker=None,
max_keys=None,
version_id_marker=None):
version_id_marker=None,
prefix=''):
bucket = self.get_bucket(bucket_name)
if any((delimiter, encoding_type, key_marker, version_id_marker)):
raise NotImplementedError(
"Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker")
return itertools.chain(*(l for _, l in bucket.keys.iterlists()))
return itertools.chain(*(l for key, l in bucket.keys.iterlists() if key.startswith(prefix)))
def get_bucket_policy(self, bucket_name):
return self.get_bucket(bucket_name).policy
@ -509,6 +611,22 @@ class S3Backend(BaseBackend):
key.set_tagging(tagging)
return key
def put_bucket_tagging(self, bucket_name, tagging):
bucket = self.get_bucket(bucket_name)
bucket.set_tags(tagging)
def delete_bucket_tagging(self, bucket_name):
bucket = self.get_bucket(bucket_name)
bucket.delete_tags()
def put_bucket_cors(self, bucket_name, cors_rules):
bucket = self.get_bucket(bucket_name)
bucket.set_cors(cors_rules)
def delete_bucket_cors(self, bucket_name):
bucket = self.get_bucket(bucket_name)
bucket.delete_cors()
def initiate_multipart(self, bucket_name, key_name, metadata):
bucket = self.get_bucket(bucket_name)
new_multipart = FakeMultipart(key_name, metadata)

236
moto/s3/responses.py Normal file → Executable file
View File

@ -188,7 +188,8 @@ class ResponseObject(_TemplateEnvironmentMixin):
elif 'lifecycle' in querystring:
bucket = self.backend.get_bucket(bucket_name)
if not bucket.rules:
return 404, {}, "NoSuchLifecycleConfiguration"
template = self.response_template(S3_NO_LIFECYCLE)
return 404, {}, template.render(bucket_name=bucket_name)
template = self.response_template(
S3_BUCKET_LIFECYCLE_CONFIGURATION)
return template.render(rules=bucket.rules)
@ -205,17 +206,35 @@ class ResponseObject(_TemplateEnvironmentMixin):
elif 'website' in querystring:
website_configuration = self.backend.get_bucket_website_configuration(
bucket_name)
if not website_configuration:
template = self.response_template(S3_NO_BUCKET_WEBSITE_CONFIG)
return 404, {}, template.render(bucket_name=bucket_name)
return website_configuration
elif 'acl' in querystring:
bucket = self.backend.get_bucket(bucket_name)
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return template.render(obj=bucket)
elif 'tagging' in querystring:
bucket = self.backend.get_bucket(bucket_name)
# "Special Error" if no tags:
if len(bucket.tagging.tag_set.tags) == 0:
template = self.response_template(S3_NO_BUCKET_TAGGING)
return 404, {}, template.render(bucket_name=bucket_name)
template = self.response_template(S3_BUCKET_TAGGING_RESPONSE)
return template.render(bucket=bucket)
elif "cors" in querystring:
bucket = self.backend.get_bucket(bucket_name)
if len(bucket.cors) == 0:
template = self.response_template(S3_NO_CORS_CONFIG)
return 404, {}, template.render(bucket_name=bucket_name)
template = self.response_template(S3_BUCKET_CORS_RESPONSE)
return template.render(bucket=bucket)
elif 'versions' in querystring:
delimiter = querystring.get('delimiter', [None])[0]
encoding_type = querystring.get('encoding-type', [None])[0]
key_marker = querystring.get('key-marker', [None])[0]
max_keys = querystring.get('max-keys', [None])[0]
prefix = querystring.get('prefix', [None])[0]
prefix = querystring.get('prefix', [''])[0]
version_id_marker = querystring.get('version-id-marker', [None])[0]
bucket = self.backend.get_bucket(bucket_name)
@ -225,7 +244,8 @@ class ResponseObject(_TemplateEnvironmentMixin):
encoding_type=encoding_type,
key_marker=key_marker,
max_keys=max_keys,
version_id_marker=version_id_marker
version_id_marker=version_id_marker,
prefix=prefix
)
latest_versions = self.backend.get_bucket_latest_versions(
bucket_name=bucket_name
@ -256,15 +276,25 @@ class ResponseObject(_TemplateEnvironmentMixin):
if prefix and isinstance(prefix, six.binary_type):
prefix = prefix.decode("utf-8")
delimiter = querystring.get('delimiter', [None])[0]
max_keys = int(querystring.get('max-keys', [1000])[0])
marker = querystring.get('marker', [None])[0]
result_keys, result_folders = self.backend.prefix_query(
bucket, prefix, delimiter)
if marker:
result_keys = self._get_results_from_token(result_keys, marker)
result_keys, is_truncated, _ = self._truncate_result(result_keys, max_keys)
template = self.response_template(S3_BUCKET_GET_RESPONSE)
return 200, {}, template.render(
bucket=bucket,
prefix=prefix,
delimiter=delimiter,
result_keys=result_keys,
result_folders=result_folders
result_folders=result_folders,
is_truncated=is_truncated,
max_keys=max_keys
)
def _handle_list_objects_v2(self, bucket_name, querystring):
@ -285,20 +315,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
if continuation_token or start_after:
limit = continuation_token or start_after
continuation_index = 0
for key in result_keys:
if key.name > limit:
break
continuation_index += 1
result_keys = result_keys[continuation_index:]
result_keys = self._get_results_from_token(result_keys, limit)
if len(result_keys) > max_keys:
is_truncated = 'true'
result_keys = result_keys[:max_keys]
next_continuation_token = result_keys[-1].name
else:
is_truncated = 'false'
next_continuation_token = None
result_keys, is_truncated, \
next_continuation_token = self._truncate_result(result_keys, max_keys)
return template.render(
bucket=bucket,
@ -313,6 +333,24 @@ class ResponseObject(_TemplateEnvironmentMixin):
start_after=None if continuation_token else start_after
)
def _get_results_from_token(self, result_keys, token):
continuation_index = 0
for key in result_keys:
if key.name > token:
break
continuation_index += 1
return result_keys[continuation_index:]
def _truncate_result(self, result_keys, max_keys):
if len(result_keys) > max_keys:
is_truncated = 'true'
result_keys = result_keys[:max_keys]
next_continuation_token = result_keys[-1].name
else:
is_truncated = 'false'
next_continuation_token = None
return result_keys, is_truncated, next_continuation_token
def _bucket_response_put(self, request, body, region_name, bucket_name, querystring, headers):
if not request.headers.get('Content-Length'):
return 411, {}, "Content-Length required"
@ -335,13 +373,23 @@ class ResponseObject(_TemplateEnvironmentMixin):
self.backend.set_bucket_policy(bucket_name, body)
return 'True'
elif 'acl' in querystring:
acl = self._acl_from_headers(request.headers)
# TODO: Support the XML-based ACL format
self.backend.set_bucket_acl(bucket_name, acl)
self.backend.set_bucket_acl(bucket_name, self._acl_from_headers(request.headers))
return ""
elif "tagging" in querystring:
tagging = self._bucket_tagging_from_xml(body)
self.backend.put_bucket_tagging(bucket_name, tagging)
return ""
elif 'website' in querystring:
self.backend.set_bucket_website_configuration(bucket_name, body)
return ""
elif "cors" in querystring:
from moto.s3.exceptions import MalformedXML
try:
self.backend.put_bucket_cors(bucket_name, self._cors_from_xml(body))
return ""
except KeyError:
raise MalformedXML()
else:
if body:
try:
@ -358,6 +406,11 @@ class ResponseObject(_TemplateEnvironmentMixin):
new_bucket = self.backend.get_bucket(bucket_name)
else:
raise
if 'x-amz-acl' in request.headers:
# TODO: Support the XML-based ACL format
self.backend.set_bucket_acl(bucket_name, self._acl_from_headers(request.headers))
template = self.response_template(S3_BUCKET_CREATE_RESPONSE)
return 200, {}, template.render(bucket=new_bucket)
@ -365,6 +418,12 @@ class ResponseObject(_TemplateEnvironmentMixin):
if 'policy' in querystring:
self.backend.delete_bucket_policy(bucket_name, body)
return 204, {}, ""
elif "tagging" in querystring:
self.backend.delete_bucket_tagging(bucket_name)
return 204, {}, ""
elif "cors" in querystring:
self.backend.delete_bucket_cors(bucket_name)
return 204, {}, ""
elif 'lifecycle' in querystring:
bucket = self.backend.get_bucket(bucket_name)
bucket.delete_lifecycle()
@ -481,6 +540,23 @@ class ResponseObject(_TemplateEnvironmentMixin):
key_name = self.parse_key_name(request, parsed_url.path)
bucket_name = self.parse_bucket_name_from_url(request, full_url)
# Because we patch the requests library the boto/boto3 API
# requests go through this method but so do
# `requests.get("https://bucket-name.s3.amazonaws.com/file-name")`
# Here we deny public access to private files by checking the
# ACL and checking for the mere presence of an Authorization
# header.
if 'Authorization' not in request.headers:
if hasattr(request, 'url'):
signed_url = 'Signature=' in request.url
elif hasattr(request, 'requestline'):
signed_url = 'Signature=' in request.path
key = self.backend.get_key(bucket_name, key_name)
if key:
if not key.acl.public_read and not signed_url:
return 403, {}, ""
if hasattr(request, 'body'):
# Boto
body = request.body
@ -566,6 +642,8 @@ class ResponseObject(_TemplateEnvironmentMixin):
storage_class = request.headers.get('x-amz-storage-class', 'STANDARD')
acl = self._acl_from_headers(request.headers)
if acl is None:
acl = self.backend.get_bucket(bucket_name).acl
tagging = self._tagging_from_headers(request.headers)
if 'acl' in query:
@ -696,6 +774,32 @@ class ResponseObject(_TemplateEnvironmentMixin):
tagging = FakeTagging(tag_set)
return tagging
def _bucket_tagging_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
tags = []
# Optional if no tags are being sent:
if parsed_xml['Tagging'].get('TagSet'):
# If there is only 1 tag, then it's not a list:
if not isinstance(parsed_xml['Tagging']['TagSet']['Tag'], list):
tags.append(FakeTag(parsed_xml['Tagging']['TagSet']['Tag']['Key'],
parsed_xml['Tagging']['TagSet']['Tag']['Value']))
else:
for tag in parsed_xml['Tagging']['TagSet']['Tag']:
tags.append(FakeTag(tag['Key'], tag['Value']))
tag_set = FakeTagSet(tags)
tagging = FakeTagging(tag_set)
return tagging
def _cors_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list):
return [cors for cors in parsed_xml["CORSConfiguration"]["CORSRule"]]
return [parsed_xml["CORSConfiguration"]["CORSRule"]]
def _key_response_delete(self, bucket_name, query, key_name, headers):
if query.get('uploadId'):
upload_id = query['uploadId'][0]
@ -775,9 +879,9 @@ S3_BUCKET_GET_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<Name>{{ bucket.name }}</Name>
<Prefix>{{ prefix }}</Prefix>
<MaxKeys>1000</MaxKeys>
<MaxKeys>{{ max_keys }}</MaxKeys>
<Delimiter>{{ delimiter }}</Delimiter>
<IsTruncated>false</IsTruncated>
<IsTruncated>{{ is_truncated }}</IsTruncated>
{% for key in result_keys %}
<Contents>
<Key>{{ key.name }}</Key>
@ -1022,6 +1126,46 @@ S3_OBJECT_TAGGING_RESPONSE = """\
</TagSet>
</Tagging>"""
S3_BUCKET_TAGGING_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Tagging>
<TagSet>
{% for tag in bucket.tagging.tag_set.tags %}
<Tag>
<Key>{{ tag.key }}</Key>
<Value>{{ tag.value }}</Value>
</Tag>
{% endfor %}
</TagSet>
</Tagging>"""
S3_BUCKET_CORS_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<CORSConfiguration>
{% for cors in bucket.cors %}
<CORSRule>
{% for origin in cors.allowed_origins %}
<AllowedOrigin>{{ origin }}</AllowedOrigin>
{% endfor %}
{% for method in cors.allowed_methods %}
<AllowedMethod>{{ method }}</AllowedMethod>
{% endfor %}
{% if cors.allowed_headers is not none %}
{% for header in cors.allowed_headers %}
<AllowedHeader>{{ header }}</AllowedHeader>
{% endfor %}
{% endif %}
{% if cors.exposed_headers is not none %}
{% for header in cors.exposed_headers %}
<ExposedHeader>{{ header }}</ExposedHeader>
{% endfor %}
{% endif %}
{% if cors.max_age_seconds is not none %}
<MaxAgeSeconds>{{ cors.max_age_seconds }}</MaxAgeSeconds>
{% endif %}
</CORSRule>
{% endfor %}
</CORSConfiguration>
"""
S3_OBJECT_COPY_RESPONSE = """\
<CopyObjectResult xmlns="http://doc.s3.amazonaws.com/2006-03-01">
<ETag>{{ key.etag }}</ETag>
@ -1114,3 +1258,53 @@ S3_NO_POLICY = """<?xml version="1.0" encoding="UTF-8"?>
<HostId>9Gjjt1m+cjU4OPvX9O9/8RuvnG41MRb/18Oux2o5H5MY7ISNTlXN+Dz9IG62/ILVxhAGI0qyPfg=</HostId>
</Error>
"""
S3_NO_LIFECYCLE = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>NoSuchLifecycleConfiguration</Code>
<Message>The lifecycle configuration does not exist</Message>
<BucketName>{{ bucket_name }}</BucketName>
<RequestId>44425877V1D0A2F9</RequestId>
<HostId>9Gjjt1m+cjU4OPvX9O9/8RuvnG41MRb/18Oux2o5H5MY7ISNTlXN+Dz9IG62/ILVxhAGI0qyPfg=</HostId>
</Error>
"""
S3_NO_BUCKET_TAGGING = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>NoSuchTagSet</Code>
<Message>The TagSet does not exist</Message>
<BucketName>{{ bucket_name }}</BucketName>
<RequestId>44425877V1D0A2F9</RequestId>
<HostId>9Gjjt1m+cjU4OPvX9O9/8RuvnG41MRb/18Oux2o5H5MY7ISNTlXN+Dz9IG62/ILVxhAGI0qyPfg=</HostId>
</Error>
"""
S3_NO_BUCKET_WEBSITE_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>NoSuchWebsiteConfiguration</Code>
<Message>The specified bucket does not have a website configuration</Message>
<BucketName>{{ bucket_name }}</BucketName>
<RequestId>44425877V1D0A2F9</RequestId>
<HostId>9Gjjt1m+cjU4OPvX9O9/8RuvnG41MRb/18Oux2o5H5MY7ISNTlXN+Dz9IG62/ILVxhAGI0qyPfg=</HostId>
</Error>
"""
S3_INVALID_CORS_REQUEST = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>NoSuchWebsiteConfiguration</Code>
<Message>The specified bucket does not have a website configuration</Message>
<BucketName>{{ bucket_name }}</BucketName>
<RequestId>44425877V1D0A2F9</RequestId>
<HostId>9Gjjt1m+cjU4OPvX9O9/8RuvnG41MRb/18Oux2o5H5MY7ISNTlXN+Dz9IG62/ILVxhAGI0qyPfg=</HostId>
</Error>
"""
S3_NO_CORS_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>NoSuchCORSConfiguration</Code>
<Message>The CORS configuration does not exist</Message>
<BucketName>{{ bucket_name }}</BucketName>
<RequestId>44425877V1D0A2F9</RequestId>
<HostId>9Gjjt1m+cjU4OPvX9O9/8RuvnG41MRb/18Oux2o5H5MY7ISNTlXN+Dz9IG62/ILVxhAGI0qyPfg=</HostId>
</Error>
"""

View File

@ -4,7 +4,7 @@ from .responses import S3ResponseInstance
url_bases = [
"https?://s3(.*).amazonaws.com",
"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com"
r"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com"
]

View File

@ -1,22 +1,23 @@
from __future__ import unicode_literals
import argparse
import json
import re
import sys
import argparse
import six
from six.moves.urllib.parse import urlencode
from threading import Lock
import six
from flask import Flask
from flask.testing import FlaskClient
from six.moves.urllib.parse import urlencode
from werkzeug.routing import BaseConverter
from werkzeug.serving import run_simple
from moto.backends import BACKENDS
from moto.core.utils import convert_flask_to_httpretty_response
HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"]
@ -61,7 +62,7 @@ class DomainDispatcherApplication(object):
host = "instance_metadata"
else:
host = environ['HTTP_HOST'].split(':')[0]
if host == "localhost":
if host in {'localhost', 'motoserver'} or host.startswith("192.168."):
# Fall back to parsing auth header to find service
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
try:
@ -139,10 +140,13 @@ def create_backend_app(service):
else:
endpoint = None
if endpoint in backend_app.view_functions:
original_endpoint = endpoint
index = 2
while endpoint in backend_app.view_functions:
# HACK: Sometimes we map the same view to multiple url_paths. Flask
# requries us to have different names.
endpoint += "2"
endpoint = original_endpoint + str(index)
index += 1
backend_app.add_url_rule(
url_path,

View File

@ -24,3 +24,11 @@ class SnsEndpointDisabled(RESTError):
def __init__(self, message):
super(SnsEndpointDisabled, self).__init__(
"EndpointDisabled", message)
class SNSInvalidParameter(RESTError):
code = 400
def __init__(self, message):
super(SNSInvalidParameter, self).__init__(
"InvalidParameter", message)

View File

@ -12,8 +12,10 @@ from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.sqs import sqs_backends
from moto.awslambda import lambda_backends
from .exceptions import (
SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled
SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter
)
from .utils import make_arn_for_topic, make_arn_for_subscription
@ -76,15 +78,23 @@ class Subscription(BaseModel):
self.endpoint = endpoint
self.protocol = protocol
self.arn = make_arn_for_subscription(self.topic.arn)
self.attributes = {}
self.confirmed = False
def publish(self, message, message_id):
if self.protocol == 'sqs':
queue_name = self.endpoint.split(":")[-1]
region = self.endpoint.split(":")[3]
sqs_backends[region].send_message(queue_name, message)
enveloped_message = json.dumps(self.get_post_data(message, message_id), sort_keys=True, indent=2, separators=(',', ': '))
sqs_backends[region].send_message(queue_name, enveloped_message)
elif self.protocol in ['http', 'https']:
post_data = self.get_post_data(message, message_id)
requests.post(self.endpoint, json=post_data)
elif self.protocol == 'lambda':
# TODO: support bad function name
function_name = self.endpoint.split(":")[-1]
region = self.arn.split(':')[3]
lambda_backends[region].send_message(function_name, message)
def get_post_data(self, message, message_id):
return {
@ -170,12 +180,18 @@ class SNSBackend(BaseBackend):
self.applications = {}
self.platform_endpoints = {}
self.region_name = region_name
self.sms_attributes = {}
self.opt_out_numbers = ['+447420500600', '+447420505401', '+447632960543', '+447632960028', '+447700900149', '+447700900550', '+447700900545', '+447700900907']
self.permissions = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def update_sms_attributes(self, attrs):
self.sms_attributes.update(attrs)
def create_topic(self, name):
topic = Topic(name, self)
self.topics[topic.arn] = topic
@ -212,6 +228,12 @@ class SNSBackend(BaseBackend):
except KeyError:
raise SNSNotFoundError("Topic with arn {0} not found".format(arn))
def get_topic_from_phone_number(self, number):
for subscription in self.subscriptions.values():
if subscription.protocol == 'sms' and subscription.endpoint == number:
return subscription.topic.arn
raise SNSNotFoundError('Could not find valid subscription')
def set_topic_attribute(self, topic_arn, attribute_name, attribute_value):
topic = self.get_topic(topic_arn)
setattr(topic, attribute_name, attribute_value)
@ -300,6 +322,26 @@ class SNSBackend(BaseBackend):
raise SNSNotFoundError(
"Endpoint with arn {0} not found".format(arn))
def get_subscription_attributes(self, arn):
_subscription = [_ for _ in self.subscriptions.values() if _.arn == arn]
if not _subscription:
raise SNSNotFoundError("Subscription with arn {0} not found".format(arn))
subscription = _subscription[0]
return subscription.attributes
def set_subscription_attributes(self, arn, name, value):
if name not in ['RawMessageDelivery', 'DeliveryPolicy']:
raise SNSInvalidParameter('AttributeName')
# TODO: should do validation
_subscription = [_ for _ in self.subscriptions.values() if _.arn == arn]
if not _subscription:
raise SNSNotFoundError("Subscription with arn {0} not found".format(arn))
subscription = _subscription[0]
subscription.attributes[name] = value
sns_backends = {}
for region in boto.sns.regions():

View File

@ -1,17 +1,27 @@
from __future__ import unicode_literals
import json
import re
from collections import defaultdict
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from .models import sns_backends
from .exceptions import SNSNotFoundError
from .utils import is_e164
class SNSResponse(BaseResponse):
SMS_ATTR_REGEX = re.compile(r'^attributes\.entry\.(?P<index>\d+)\.(?P<type>key|value)$')
OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r'^\+?\d+$')
@property
def backend(self):
return sns_backends[self.region]
def _error(self, code, message, sender='Sender'):
template = self.response_template(ERROR_RESPONSE)
return template.render(code=code, message=message, sender=sender)
def _get_attributes(self):
attributes = self._get_list_prefix('Attributes.entry')
return dict(
@ -128,6 +138,13 @@ class SNSResponse(BaseResponse):
topic_arn = self._get_param('TopicArn')
endpoint = self._get_param('Endpoint')
protocol = self._get_param('Protocol')
if protocol == 'sms' and not is_e164(endpoint):
return self._error(
'InvalidParameter',
'Phone number does not meet the E164 format'
), dict(status=400)
subscription = self.backend.subscribe(topic_arn, endpoint, protocol)
if self.request_json:
@ -221,7 +238,28 @@ class SNSResponse(BaseResponse):
def publish(self):
target_arn = self._get_param('TargetArn')
topic_arn = self._get_param('TopicArn')
arn = target_arn if target_arn else topic_arn
phone_number = self._get_param('PhoneNumber')
if phone_number is not None:
# Check phone is correct syntax (e164)
if not is_e164(phone_number):
return self._error(
'InvalidParameter',
'Phone number does not meet the E164 format'
), dict(status=400)
# Look up topic arn by phone number
try:
arn = self.backend.get_topic_from_phone_number(phone_number)
except SNSNotFoundError:
return self._error(
'ParameterValueInvalid',
'Could not find topic associated with phone number'
), dict(status=400)
elif target_arn is not None:
arn = target_arn
else:
arn = topic_arn
message = self._get_param('Message')
message_id = self.backend.publish(arn, message)
@ -445,6 +483,145 @@ class SNSResponse(BaseResponse):
template = self.response_template(DELETE_ENDPOINT_TEMPLATE)
return template.render()
def get_subscription_attributes(self):
arn = self._get_param('SubscriptionArn')
attributes = self.backend.get_subscription_attributes(arn)
template = self.response_template(GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE)
return template.render(attributes=attributes)
def set_subscription_attributes(self):
arn = self._get_param('SubscriptionArn')
attr_name = self._get_param('AttributeName')
attr_value = self._get_param('AttributeValue')
self.backend.set_subscription_attributes(arn, attr_name, attr_value)
template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE)
return template.render()
def set_sms_attributes(self):
# attributes.entry.1.key
# attributes.entry.1.value
# to
# 1: {key:X, value:Y}
temp_dict = defaultdict(dict)
for key, value in self.querystring.items():
match = self.SMS_ATTR_REGEX.match(key)
if match is not None:
temp_dict[match.group('index')][match.group('type')] = value[0]
# 1: {key:X, value:Y}
# to
# X: Y
# All of this, just to take into account when people provide invalid stuff.
result = {}
for item in temp_dict.values():
if 'key' in item and 'value' in item:
result[item['key']] = item['value']
self.backend.update_sms_attributes(result)
template = self.response_template(SET_SMS_ATTRIBUTES_TEMPLATE)
return template.render()
def get_sms_attributes(self):
filter_list = set()
for key, value in self.querystring.items():
if key.startswith('attributes.member.1'):
filter_list.add(value[0])
if len(filter_list) > 0:
result = {k: v for k, v in self.backend.sms_attributes.items() if k in filter_list}
else:
result = self.backend.sms_attributes
template = self.response_template(GET_SMS_ATTRIBUTES_TEMPLATE)
return template.render(attributes=result)
def check_if_phone_number_is_opted_out(self):
number = self._get_param('phoneNumber')
if self.OPT_OUT_PHONE_NUMBER_REGEX.match(number) is None:
error_response = self._error(
code='InvalidParameter',
message='Invalid parameter: PhoneNumber Reason: input incorrectly formatted'
)
return error_response, dict(status=400)
# There should be a nicer way to set if a nubmer has opted out
template = self.response_template(CHECK_IF_OPTED_OUT_TEMPLATE)
return template.render(opt_out=str(number.endswith('99')).lower())
def list_phone_numbers_opted_out(self):
template = self.response_template(LIST_OPTOUT_TEMPLATE)
return template.render(opt_outs=self.backend.opt_out_numbers)
def opt_in_phone_number(self):
number = self._get_param('phoneNumber')
try:
self.backend.opt_out_numbers.remove(number)
except ValueError:
pass
template = self.response_template(OPT_IN_NUMBER_TEMPLATE)
return template.render()
def add_permission(self):
arn = self._get_param('TopicArn')
label = self._get_param('Label')
accounts = self._get_multi_param('AWSAccountId.member.')
action = self._get_multi_param('ActionName.member.')
if arn not in self.backend.topics:
error_response = self._error('NotFound', 'Topic does not exist')
return error_response, dict(status=404)
key = (arn, label)
self.backend.permissions[key] = {'accounts': accounts, 'action': action}
template = self.response_template(ADD_PERMISSION_TEMPLATE)
return template.render()
def remove_permission(self):
arn = self._get_param('TopicArn')
label = self._get_param('Label')
if arn not in self.backend.topics:
error_response = self._error('NotFound', 'Topic does not exist')
return error_response, dict(status=404)
try:
key = (arn, label)
del self.backend.permissions[key]
except KeyError:
pass
template = self.response_template(DEL_PERMISSION_TEMPLATE)
return template.render()
def confirm_subscription(self):
arn = self._get_param('TopicArn')
if arn not in self.backend.topics:
error_response = self._error('NotFound', 'Topic does not exist')
return error_response, dict(status=404)
# Once Tokens are stored by the `subscribe` endpoint and distributed
# to the client somehow, then we can check validity of tokens
# presented to this method. The following code works, all thats
# needed is to perform a token check and assign that value to the
# `already_subscribed` variable.
#
# token = self._get_param('Token')
# auth = self._get_param('AuthenticateOnUnsubscribe')
# if already_subscribed:
# error_response = self._error(
# code='AuthorizationError',
# message='Subscription already confirmed'
# )
# return error_response, dict(status=400)
template = self.response_template(CONFIRM_SUBSCRIPTION_TEMPLATE)
return template.render(sub_arn='{0}:68762e72-e9b1-410a-8b3b-903da69ee1d5'.format(arn))
CREATE_TOPIC_TEMPLATE = """<CreateTopicResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<CreateTopicResult>
@ -719,3 +896,110 @@ LIST_SUBSCRIPTIONS_BY_TOPIC_TEMPLATE = """<ListSubscriptionsByTopicResponse xmln
<RequestId>384ac68d-3775-11df-8963-01868b7c937a</RequestId>
</ResponseMetadata>
</ListSubscriptionsByTopicResponse>"""
# Not responding aws system attribetus like 'Owner' and 'SubscriptionArn'
GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE = """<GetSubscriptionAttributesResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<GetSubscriptionAttributesResult>
<Attributes>
{% for name, value in attributes.items() %}
<entry>
<key>{{ name }}</key>
<value>{{ value }}</value>
</entry>
{% endfor %}
</Attributes>
</GetSubscriptionAttributesResult>
<ResponseMetadata>
<RequestId>057f074c-33a7-11df-9540-99d0768312d3</RequestId>
</ResponseMetadata>
</GetSubscriptionAttributesResponse>"""
SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE = """<SetSubscriptionAttributesResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<ResponseMetadata>
<RequestId>a8763b99-33a7-11df-a9b7-05d48da6f042</RequestId>
</ResponseMetadata>
</SetSubscriptionAttributesResponse>"""
SET_SMS_ATTRIBUTES_TEMPLATE = """<SetSMSAttributesResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<SetSMSAttributesResult/>
<ResponseMetadata>
<RequestId>26332069-c04a-5428-b829-72524b56a364</RequestId>
</ResponseMetadata>
</SetSMSAttributesResponse>"""
GET_SMS_ATTRIBUTES_TEMPLATE = """<GetSMSAttributesResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<GetSMSAttributesResult>
<attributes>
{% for name, value in attributes.items() %}
<entry>
<key>{{ name }}</key>
<value>{{ value }}</value>
</entry>
{% endfor %}
</attributes>
</GetSMSAttributesResult>
<ResponseMetadata>
<RequestId>287f9554-8db3-5e66-8abc-c76f0186db7e</RequestId>
</ResponseMetadata>
</GetSMSAttributesResponse>"""
CHECK_IF_OPTED_OUT_TEMPLATE = """<CheckIfPhoneNumberIsOptedOutResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<CheckIfPhoneNumberIsOptedOutResult>
<isOptedOut>{{ opt_out }}</isOptedOut>
</CheckIfPhoneNumberIsOptedOutResult>
<ResponseMetadata>
<RequestId>287f9554-8db3-5e66-8abc-c76f0186db7e</RequestId>
</ResponseMetadata>
</CheckIfPhoneNumberIsOptedOutResponse>"""
ERROR_RESPONSE = """<ErrorResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<Error>
<Type>{{ sender }}</Type>
<Code>{{ code }}</Code>
<Message>{{ message }}</Message>
</Error>
<RequestId>9dd01905-5012-5f99-8663-4b3ecd0dfaef</RequestId>
</ErrorResponse>"""
LIST_OPTOUT_TEMPLATE = """<ListPhoneNumbersOptedOutResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<ListPhoneNumbersOptedOutResult>
<phoneNumbers>
{% for item in opt_outs %}
<member>{{ item }}</member>
{% endfor %}
</phoneNumbers>
</ListPhoneNumbersOptedOutResult>
<ResponseMetadata>
<RequestId>985e196d-a237-51b6-b33a-4b5601276b38</RequestId>
</ResponseMetadata>
</ListPhoneNumbersOptedOutResponse>"""
OPT_IN_NUMBER_TEMPLATE = """<OptInPhoneNumberResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<OptInPhoneNumberResult/>
<ResponseMetadata>
<RequestId>4c61842c-0796-50ef-95ac-d610c0bc8cf8</RequestId>
</ResponseMetadata>
</OptInPhoneNumberResponse>"""
ADD_PERMISSION_TEMPLATE = """<AddPermissionResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<ResponseMetadata>
<RequestId>c046e713-c5ff-5888-a7bc-b52f0e4f1299</RequestId>
</ResponseMetadata>
</AddPermissionResponse>"""
DEL_PERMISSION_TEMPLATE = """<RemovePermissionResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<ResponseMetadata>
<RequestId>e767cc9f-314b-5e1b-b283-9ea3fd4e38a3</RequestId>
</ResponseMetadata>
</RemovePermissionResponse>"""
CONFIRM_SUBSCRIPTION_TEMPLATE = """<ConfirmSubscriptionResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<ConfirmSubscriptionResult>
<SubscriptionArn>{{ sub_arn }}</SubscriptionArn>
</ConfirmSubscriptionResult>
<ResponseMetadata>
<RequestId>16eb4dde-7b3c-5b3e-a22a-1fe2a92d3293</RequestId>
</ResponseMetadata>
</ConfirmSubscriptionResponse>"""

View File

@ -1,6 +1,9 @@
from __future__ import unicode_literals
import re
import uuid
E164_REGEX = re.compile(r'^\+?[1-9]\d{1,14}$')
def make_arn_for_topic(account_id, name, region_name):
return "arn:aws:sns:{0}:{1}:{2}".format(region_name, account_id, name)
@ -9,3 +12,7 @@ def make_arn_for_topic(account_id, name, region_name):
def make_arn_for_subscription(topic_arn):
subscription_id = uuid.uuid4()
return "{0}:{1}".format(topic_arn, subscription_id)
def is_e164(number):
return E164_REGEX.match(number) is not None

View File

@ -12,10 +12,7 @@ import boto.sqs
from moto.core import BaseBackend, BaseModel
from moto.core.utils import camelcase_to_underscores, get_random_message_id, unix_time, unix_time_millis
from .utils import generate_receipt_handle
from .exceptions import (
ReceiptHandleIsInvalid,
MessageNotInflight
)
from .exceptions import ReceiptHandleIsInvalid, MessageNotInflight, MessageAttributesInvalid
DEFAULT_ACCOUNT_ID = 123456789012
DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"
@ -59,6 +56,7 @@ class Message(BaseModel):
return str.encode('utf-8')
return str
md5 = hashlib.md5()
struct_format = "!I".encode('ascii') # ensure it's a bytestring
for name in sorted(self.message_attributes.keys()):
attr = self.message_attributes[name]
data_type = attr['data_type']
@ -67,10 +65,10 @@ class Message(BaseModel):
# Each part of each attribute is encoded right after it's
# own length is packed into a 4-byte integer
# 'timestamp' -> b'\x00\x00\x00\t'
encoded += struct.pack("!I", len(utf8(name))) + utf8(name)
encoded += struct.pack(struct_format, len(utf8(name))) + utf8(name)
# The datatype is additionally given a final byte
# representing which type it is
encoded += struct.pack("!I", len(data_type)) + utf8(data_type)
encoded += struct.pack(struct_format, len(data_type)) + utf8(data_type)
encoded += TRANSPORT_TYPE_ENCODINGS[data_type]
if data_type == 'String' or data_type == 'Number':
@ -86,7 +84,7 @@ class Message(BaseModel):
# MD5 so as not to break client softwre
return('deadbeefdeadbeefdeadbeefdeadbeef')
encoded += struct.pack("!I", len(utf8(value))) + utf8(value)
encoded += struct.pack(struct_format, len(utf8(value))) + utf8(value)
md5.update(encoded)
return md5.hexdigest()
@ -150,8 +148,12 @@ class Queue(BaseModel):
camelcase_attributes = ['ApproximateNumberOfMessages',
'ApproximateNumberOfMessagesDelayed',
'ApproximateNumberOfMessagesNotVisible',
'ContentBasedDeduplication',
'CreatedTimestamp',
'DelaySeconds',
'FifoQueue',
'KmsDataKeyReusePeriodSeconds',
'KmsMasterKeyId',
'LastModifiedTimestamp',
'MaximumMessageSize',
'MessageRetentionPeriod',
@ -160,25 +162,35 @@ class Queue(BaseModel):
'VisibilityTimeout',
'WaitTimeSeconds']
def __init__(self, name, visibility_timeout, wait_time_seconds, region):
def __init__(self, name, region, **kwargs):
self.name = name
self.visibility_timeout = visibility_timeout or 30
self.visibility_timeout = int(kwargs.get('VisibilityTimeout', 30))
self.region = region
# wait_time_seconds will be set to immediate return messages
self.wait_time_seconds = int(wait_time_seconds) if wait_time_seconds else 0
self._messages = []
now = unix_time()
# kwargs can also have:
# [Policy, RedrivePolicy]
self.fifo_queue = kwargs.get('FifoQueue', 'false') == 'true'
self.content_based_deduplication = kwargs.get('ContentBasedDeduplication', 'false') == 'true'
self.kms_master_key_id = kwargs.get('KmsMasterKeyId', 'alias/aws/sqs')
self.kms_data_key_reuse_period_seconds = int(kwargs.get('KmsDataKeyReusePeriodSeconds', 300))
self.created_timestamp = now
self.delay_seconds = 0
self.delay_seconds = int(kwargs.get('DelaySeconds', 0))
self.last_modified_timestamp = now
self.maximum_message_size = 64 << 10
self.message_retention_period = 86400 * 4 # four days
self.queue_arn = 'arn:aws:sqs:{0}:123456789012:{1}'.format(
self.region, self.name)
self.receive_message_wait_time_seconds = 0
self.maximum_message_size = int(kwargs.get('MaximumMessageSize', 64 << 10))
self.message_retention_period = int(kwargs.get('MessageRetentionPeriod', 86400 * 4)) # four days
self.queue_arn = 'arn:aws:sqs:{0}:123456789012:{1}'.format(self.region, self.name)
self.receive_message_wait_time_seconds = int(kwargs.get('ReceiveMessageWaitTimeSeconds', 0))
# wait_time_seconds will be set to immediate return messages
self.wait_time_seconds = int(kwargs.get('WaitTimeSeconds', 0))
# Check some conditions
if self.fifo_queue and not self.name.endswith('.fifo'):
raise MessageAttributesInvalid('Queue name must end in .fifo for FIFO queues')
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -187,8 +199,8 @@ class Queue(BaseModel):
sqs_backend = sqs_backends[region_name]
return sqs_backend.create_queue(
name=properties['QueueName'],
visibility_timeout=properties.get('VisibilityTimeout'),
wait_time_seconds=properties.get('WaitTimeSeconds')
region=region_name,
**properties
)
@classmethod
@ -232,8 +244,10 @@ class Queue(BaseModel):
def attributes(self):
result = {}
for attribute in self.camelcase_attributes:
result[attribute] = getattr(
self, camelcase_to_underscores(attribute))
attr = getattr(self, camelcase_to_underscores(attribute))
if isinstance(attr, bool):
attr = str(attr).lower()
result[attribute] = attr
return result
def url(self, request_url):
@ -267,11 +281,14 @@ class SQSBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region_name)
def create_queue(self, name, visibility_timeout, wait_time_seconds):
def create_queue(self, name, **kwargs):
queue = self.queues.get(name)
if queue is None:
queue = Queue(name, visibility_timeout,
wait_time_seconds, self.region_name)
try:
kwargs.pop('region')
except KeyError:
pass
queue = Queue(name, region=self.region_name, **kwargs)
self.queues[name] = queue
return queue

View File

@ -28,8 +28,7 @@ class SQSResponse(BaseResponse):
@property
def attribute(self):
if not hasattr(self, '_attribute'):
self._attribute = dict([(a['name'], a['value'])
for a in self._get_list_prefix('Attribute')])
self._attribute = self._get_map_prefix('Attribute', key_end='Name', value_end='Value')
return self._attribute
def _get_queue_name(self):
@ -58,17 +57,25 @@ class SQSResponse(BaseResponse):
return 404, headers, ERROR_INEXISTENT_QUEUE
return status_code, headers, body
def _error(self, code, message, status=400):
template = self.response_template(ERROR_TEMPLATE)
return template.render(code=code, message=message), dict(status=status)
def create_queue(self):
request_url = urlparse(self.uri)
queue_name = self.querystring.get("QueueName")[0]
queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=self.attribute.get('VisibilityTimeout'),
wait_time_seconds=self.attribute.get('WaitTimeSeconds'))
queue_name = self._get_param("QueueName")
try:
queue = self.sqs_backend.create_queue(queue_name, **self.attribute)
except MessageAttributesInvalid as e:
return self._error('InvalidParameterValue', e.description)
template = self.response_template(CREATE_QUEUE_RESPONSE)
return template.render(queue=queue, request_url=request_url)
def get_queue_url(self):
request_url = urlparse(self.uri)
queue_name = self.querystring.get("QueueName")[0]
queue_name = self._get_param("QueueName")
queue = self.sqs_backend.get_queue(queue_name)
if queue:
template = self.response_template(GET_QUEUE_URL_RESPONSE)
@ -78,14 +85,14 @@ class SQSResponse(BaseResponse):
def list_queues(self):
request_url = urlparse(self.uri)
queue_name_prefix = self.querystring.get("QueueNamePrefix", [None])[0]
queue_name_prefix = self._get_param('QueueNamePrefix')
queues = self.sqs_backend.list_queues(queue_name_prefix)
template = self.response_template(LIST_QUEUES_RESPONSE)
return template.render(queues=queues, request_url=request_url)
def change_message_visibility(self):
queue_name = self._get_queue_name()
receipt_handle = self.querystring.get("ReceiptHandle")[0]
receipt_handle = self._get_param('ReceiptHandle')
try:
visibility_timeout = self._get_validated_visibility_timeout()
@ -111,19 +118,15 @@ class SQSResponse(BaseResponse):
return template.render(queue=queue)
def set_queue_attributes(self):
# TODO validate self.get_param('QueueUrl')
queue_name = self._get_queue_name()
if "Attribute.Name" in self.querystring:
key = camelcase_to_underscores(
self.querystring.get("Attribute.Name")[0])
value = self.querystring.get("Attribute.Value")[0]
self.sqs_backend.set_queue_attribute(queue_name, key, value)
for a in self._get_list_prefix("Attribute"):
key = camelcase_to_underscores(a["name"])
value = a["value"]
for key, value in self.attribute.items():
key = camelcase_to_underscores(key)
self.sqs_backend.set_queue_attribute(queue_name, key, value)
return SET_QUEUE_ATTRIBUTE_RESPONSE
def delete_queue(self):
# TODO validate self.get_param('QueueUrl')
queue_name = self._get_queue_name()
queue = self.sqs_backend.delete_queue(queue_name)
if not queue:
@ -133,17 +136,12 @@ class SQSResponse(BaseResponse):
return template.render(queue=queue)
def send_message(self):
message = self.querystring.get("MessageBody")[0]
delay_seconds = self.querystring.get('DelaySeconds')
message = self._get_param('MessageBody')
delay_seconds = int(self._get_param('DelaySeconds', 0))
if len(message) > MAXIMUM_MESSAGE_LENGTH:
return ERROR_TOO_LONG_RESPONSE, dict(status=400)
if delay_seconds:
delay_seconds = int(delay_seconds[0])
else:
delay_seconds = 0
try:
message_attributes = parse_message_attributes(self.querystring)
except MessageAttributesInvalid as e:
@ -470,3 +468,13 @@ ERROR_INEXISTENT_QUEUE = """<ErrorResponse xmlns="http://queue.amazonaws.com/doc
</Error>
<RequestId>b8bc806b-fa6b-53b5-8be8-cfa2f9836bc3</RequestId>
</ErrorResponse>"""
ERROR_TEMPLATE = """<ErrorResponse xmlns="http://queue.amazonaws.com/doc/2012-11-05/">
<Error>
<Type>Sender</Type>
<Code>{{ code }}</Code>
<Message>{{ message }}</Message>
<Detail/>
</Error>
<RequestId>6fde8d1e-52cd-4581-8cd9-c512f4c64223</RequestId>
</ErrorResponse>"""

View File

@ -52,6 +52,16 @@ class SimpleSystemManagerBackend(BaseBackend):
except KeyError:
pass
def delete_parameters(self, names):
result = []
for name in names:
try:
del self._parameters[name]
result.append(name)
except KeyError:
pass
return result
def get_all_parameters(self):
result = []
for k, _ in self._parameters.items():
@ -65,6 +75,11 @@ class SimpleSystemManagerBackend(BaseBackend):
result.append(self._parameters[name])
return result
def get_parameter(self, name, with_decryption):
if name in self._parameters:
return self._parameters[name]
return None
def put_parameter(self, name, description, value, type, keyid, overwrite):
if not overwrite and name in self._parameters:
return

View File

@ -26,6 +26,40 @@ class SimpleSystemManagerResponse(BaseResponse):
self.ssm_backend.delete_parameter(name)
return json.dumps({})
def delete_parameters(self):
names = self._get_param('Names')
result = self.ssm_backend.delete_parameters(names)
response = {
'DeletedParameters': [],
'InvalidParameters': []
}
for name in names:
if name in result:
response['DeletedParameters'].append(name)
else:
response['InvalidParameters'].append(name)
return json.dumps(response)
def get_parameter(self):
name = self._get_param('Name')
with_decryption = self._get_param('WithDecryption')
result = self.ssm_backend.get_parameter(name, with_decryption)
if result is None:
error = {
'__type': 'ParameterNotFound',
'message': 'Parameter {0} not found.'.format(name)
}
return json.dumps(error), dict(status=400)
response = {
'Parameter': result.response_object(with_decryption)
}
return json.dumps(response)
def get_parameters(self):
names = self._get_param('Names')
with_decryption = self._get_param('WithDecryption')
@ -41,6 +75,10 @@ class SimpleSystemManagerResponse(BaseResponse):
param_data = parameter.response_object(with_decryption)
response['Parameters'].append(param_data)
param_names = [param.name for param in result]
for name in names:
if name not in param_names:
response['InvalidParameters'].append(name)
return json.dumps(response)
def describe_parameters(self):

6
moto/xray/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import xray_backends
from ..core.models import base_decorator
xray_backend = xray_backends['us-east-1']
mock_xray = base_decorator(xray_backends)

39
moto/xray/exceptions.py Normal file
View File

@ -0,0 +1,39 @@
import json
class AWSError(Exception):
CODE = None
STATUS = 400
def __init__(self, message, code=None, status=None):
self.message = message
self.code = code if code is not None else self.CODE
self.status = status if status is not None else self.STATUS
def response(self):
return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status)
class InvalidRequestException(AWSError):
CODE = 'InvalidRequestException'
class BadSegmentException(Exception):
def __init__(self, seg_id=None, code=None, message=None):
self.id = seg_id
self.code = code
self.message = message
def __repr__(self):
return '<BadSegment {0}>'.format('-'.join([self.id, self.code, self.message]))
def to_dict(self):
result = {}
if self.id is not None:
result['Id'] = self.id
if self.code is not None:
result['ErrorCode'] = self.code
if self.message is not None:
result['Message'] = self.message
return result

251
moto/xray/models.py Normal file
View File

@ -0,0 +1,251 @@
from __future__ import unicode_literals
import bisect
import datetime
from collections import defaultdict
import json
from moto.core import BaseBackend, BaseModel
from moto.ec2 import ec2_backends
from .exceptions import BadSegmentException, AWSError
class TelemetryRecords(BaseModel):
def __init__(self, instance_id, hostname, resource_arn, records):
self.instance_id = instance_id
self.hostname = hostname
self.resource_arn = resource_arn
self.records = records
@classmethod
def from_json(cls, json):
instance_id = json.get('EC2InstanceId', None)
hostname = json.get('Hostname')
resource_arn = json.get('ResourceARN')
telemetry_records = json['TelemetryRecords']
return cls(instance_id, hostname, resource_arn, telemetry_records)
# https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html
class TraceSegment(BaseModel):
def __init__(self, name, segment_id, trace_id, start_time, raw, end_time=None, in_progress=False, service=None, user=None,
origin=None, parent_id=None, http=None, aws=None, metadata=None, annotations=None, subsegments=None, **kwargs):
self.name = name
self.id = segment_id
self.trace_id = trace_id
self._trace_version = None
self._original_request_start_time = None
self._trace_identifier = None
self.start_time = start_time
self._start_date = None
self.end_time = end_time
self._end_date = None
self.in_progress = in_progress
self.service = service
self.user = user
self.origin = origin
self.parent_id = parent_id
self.http = http
self.aws = aws
self.metadata = metadata
self.annotations = annotations
self.subsegments = subsegments
self.misc = kwargs
# Raw json string
self.raw = raw
def __lt__(self, other):
return self.start_date < other.start_date
@property
def trace_version(self):
if self._trace_version is None:
self._trace_version = int(self.trace_id.split('-', 1)[0])
return self._trace_version
@property
def request_start_date(self):
if self._original_request_start_time is None:
start_time = int(self.trace_id.split('-')[1], 16)
self._original_request_start_time = datetime.datetime.fromtimestamp(start_time)
return self._original_request_start_time
@property
def start_date(self):
if self._start_date is None:
self._start_date = datetime.datetime.fromtimestamp(self.start_time)
return self._start_date
@property
def end_date(self):
if self._end_date is None:
self._end_date = datetime.datetime.fromtimestamp(self.end_time)
return self._end_date
@classmethod
def from_dict(cls, data, raw):
# Check manditory args
if 'id' not in data:
raise BadSegmentException(code='MissingParam', message='Missing segment ID')
seg_id = data['id']
data['segment_id'] = seg_id # Just adding this key for future convenience
for arg in ('name', 'trace_id', 'start_time'):
if arg not in data:
raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing segment ID')
if 'end_time' not in data and 'in_progress' not in data:
raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time or in_progress')
if 'end_time' not in data and data['in_progress'] == 'false':
raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time')
return cls(raw=raw, **data)
class SegmentCollection(object):
def __init__(self):
self._traces = defaultdict(self._new_trace_item)
@staticmethod
def _new_trace_item():
return {
'start_date': datetime.datetime(1970, 1, 1),
'end_date': datetime.datetime(1970, 1, 1),
'finished': False,
'trace_id': None,
'segments': []
}
def put_segment(self, segment):
# insert into a sorted list
bisect.insort_left(self._traces[segment.trace_id]['segments'], segment)
# Get the last segment (takes into account incorrect ordering)
# and if its the last one, mark trace as complete
if self._traces[segment.trace_id]['segments'][-1].end_time is not None:
self._traces[segment.trace_id]['finished'] = True
start_time = self._traces[segment.trace_id]['segments'][0].start_date
end_time = self._traces[segment.trace_id]['segments'][-1].end_date
self._traces[segment.trace_id]['start_date'] = start_time
self._traces[segment.trace_id]['end_date'] = end_time
self._traces[segment.trace_id]['trace_id'] = segment.trace_id
# Todo consolidate trace segments into a trace.
# not enough working knowledge of xray to do this
def summary(self, start_time, end_time, filter_expression=None, sampling=False):
# This beast https://docs.aws.amazon.com/xray/latest/api/API_GetTraceSummaries.html#API_GetTraceSummaries_ResponseSyntax
if filter_expression is not None:
raise AWSError('Not implemented yet - moto', code='InternalFailure', status=500)
summaries = []
for tid, trace in self._traces.items():
if trace['finished'] and start_time < trace['start_date'] and trace['end_date'] < end_time:
duration = int((trace['end_date'] - trace['start_date']).total_seconds())
# this stuff is mostly guesses, refer to TODO above
has_error = any(['error' in seg.misc for seg in trace['segments']])
has_fault = any(['fault' in seg.misc for seg in trace['segments']])
has_throttle = any(['throttle' in seg.misc for seg in trace['segments']])
# Apparently all of these options are optional
summary_part = {
'Annotations': {}, # Not implemented yet
'Duration': duration,
'HasError': has_error,
'HasFault': has_fault,
'HasThrottle': has_throttle,
'Http': {}, # Not implemented yet
'Id': tid,
'IsParital': False, # needs lots more work to work on partials
'ResponseTime': 1, # definitely 1ms resposnetime
'ServiceIds': [], # Not implemented yet
'Users': {} # Not implemented yet
}
summaries.append(summary_part)
result = {
"ApproximateTime": int((datetime.datetime.now() - datetime.datetime(1970, 1, 1)).total_seconds()),
"TracesProcessedCount": len(summaries),
"TraceSummaries": summaries
}
return result
def get_trace_ids(self, trace_ids):
traces = []
unprocessed = []
# Its a default dict
existing_trace_ids = list(self._traces.keys())
for trace_id in trace_ids:
if trace_id in existing_trace_ids:
traces.append(self._traces[trace_id])
else:
unprocessed.append(trace_id)
return traces, unprocessed
class XRayBackend(BaseBackend):
def __init__(self):
self._telemetry_records = []
self._segment_collection = SegmentCollection()
def add_telemetry_records(self, json):
self._telemetry_records.append(
TelemetryRecords.from_json(json)
)
def process_segment(self, doc):
try:
data = json.loads(doc)
except ValueError:
raise BadSegmentException(code='JSONFormatError', message='Bad JSON data')
try:
# Get Segment Object
segment = TraceSegment.from_dict(data, raw=doc)
except ValueError:
raise BadSegmentException(code='JSONFormatError', message='Bad JSON data')
try:
# Store Segment Object
self._segment_collection.put_segment(segment)
except Exception as err:
raise BadSegmentException(seg_id=segment.id, code='InternalFailure', message=str(err))
def get_trace_summary(self, start_time, end_time, filter_expression, summaries):
return self._segment_collection.summary(start_time, end_time, filter_expression, summaries)
def get_trace_ids(self, trace_ids, next_token):
traces, unprocessed_ids = self._segment_collection.get_trace_ids(trace_ids)
result = {
'Traces': [],
'UnprocessedTraceIds': unprocessed_ids
}
for trace in traces:
segments = []
for segment in trace['segments']:
segments.append({
'Id': segment.id,
'Document': segment.raw
})
result['Traces'].append({
'Duration': int((trace['end_date'] - trace['start_date']).total_seconds()),
'Id': trace['trace_id'],
'Segments': segments
})
return result
xray_backends = {}
for region, ec2_backend in ec2_backends.items():
xray_backends[region] = XRayBackend()

150
moto/xray/responses.py Normal file
View File

@ -0,0 +1,150 @@
from __future__ import unicode_literals
import json
import datetime
from moto.core.responses import BaseResponse
from six.moves.urllib.parse import urlsplit
from .models import xray_backends
from .exceptions import AWSError, BadSegmentException
class XRayResponse(BaseResponse):
def _error(self, code, message):
return json.dumps({'__type': code, 'message': message}), dict(status=400)
@property
def xray_backend(self):
return xray_backends[self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param, default=None):
return self.request_params.get(param, default)
def _get_action(self):
# Amazon is just calling urls like /TelemetryRecords etc...
# This uses the value after / as the camalcase action, which then
# gets converted in call_action to find the following methods
return urlsplit(self.uri).path.lstrip('/')
# PutTelemetryRecords
def telemetry_records(self):
try:
self.xray_backend.add_telemetry_records(self.request_params)
except AWSError as err:
return err.response()
return ''
# PutTraceSegments
def trace_segments(self):
docs = self._get_param('TraceSegmentDocuments')
if docs is None:
msg = 'Parameter TraceSegmentDocuments is missing'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
# Raises an exception that contains info about a bad segment,
# the object also has a to_dict() method
bad_segments = []
for doc in docs:
try:
self.xray_backend.process_segment(doc)
except BadSegmentException as bad_seg:
bad_segments.append(bad_seg)
except Exception as err:
return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500)
result = {'UnprocessedTraceSegments': [x.to_dict() for x in bad_segments]}
return json.dumps(result)
# GetTraceSummaries
def trace_summaries(self):
start_time = self._get_param('StartTime')
end_time = self._get_param('EndTime')
if start_time is None:
msg = 'Parameter StartTime is missing'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
if end_time is None:
msg = 'Parameter EndTime is missing'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
filter_expression = self._get_param('FilterExpression')
sampling = self._get_param('Sampling', 'false') == 'true'
try:
start_time = datetime.datetime.fromtimestamp(int(start_time))
end_time = datetime.datetime.fromtimestamp(int(end_time))
except ValueError:
msg = 'start_time and end_time are not integers'
return json.dumps({'__type': 'InvalidParameterValue', 'message': msg}), dict(status=400)
except Exception as err:
return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500)
try:
result = self.xray_backend.get_trace_summary(start_time, end_time, filter_expression, sampling)
except AWSError as err:
return err.response()
except Exception as err:
return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500)
return json.dumps(result)
# BatchGetTraces
def traces(self):
trace_ids = self._get_param('TraceIds')
next_token = self._get_param('NextToken') # not implemented yet
if trace_ids is None:
msg = 'Parameter TraceIds is missing'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
try:
result = self.xray_backend.get_trace_ids(trace_ids, next_token)
except AWSError as err:
return err.response()
except Exception as err:
return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500)
return json.dumps(result)
# GetServiceGraph - just a dummy response for now
def service_graph(self):
start_time = self._get_param('StartTime')
end_time = self._get_param('EndTime')
# next_token = self._get_param('NextToken') # not implemented yet
if start_time is None:
msg = 'Parameter StartTime is missing'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
if end_time is None:
msg = 'Parameter EndTime is missing'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
result = {
'StartTime': start_time,
'EndTime': end_time,
'Services': []
}
return json.dumps(result)
# GetTraceGraph - just a dummy response for now
def trace_graph(self):
trace_ids = self._get_param('TraceIds')
# next_token = self._get_param('NextToken') # not implemented yet
if trace_ids is None:
msg = 'Parameter TraceIds is missing'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
result = {
'Services': []
}
return json.dumps(result)

15
moto/xray/urls.py Normal file
View File

@ -0,0 +1,15 @@
from __future__ import unicode_literals
from .responses import XRayResponse
url_bases = [
"https?://xray.(.+).amazonaws.com",
]
url_paths = {
'{0}/TelemetryRecords$': XRayResponse.dispatch,
'{0}/TraceSegments$': XRayResponse.dispatch,
'{0}/Traces$': XRayResponse.dispatch,
'{0}/ServiceGraph$': XRayResponse.dispatch,
'{0}/TraceGraph$': XRayResponse.dispatch,
'{0}/TraceSummaries$': XRayResponse.dispatch,
}

View File

@ -6,6 +6,12 @@ coverage
flake8
freezegun
flask
boto>=2.45.0
boto3>=1.4.4
botocore>=1.4.28
six
botocore>=1.5.77
six>=1.9
prompt-toolkit==1.0.14
click==6.7
inflection==0.3.1
lxml==4.0.0
beautifulsoup4==4.6.0

Some files were not shown because too many files have changed in this diff Show More