Merge branch 'master' of https://github.com/spulec/moto into spulec-master

This commit is contained in:
acsbendi 2019-07-13 13:43:19 +02:00
commit 419fcf2ee9
21 changed files with 1616 additions and 479 deletions

View File

@ -2,6 +2,10 @@
Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project. Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project.
## Running the tests locally
Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
## Is there a missing feature? ## Is there a missing feature?
Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services. Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services.

View File

@ -10,7 +10,7 @@ endif
init: init:
@python setup.py develop @python setup.py develop
@pip install -r requirements.txt @pip install -r requirements-dev.txt
lint: lint:
flake8 moto flake8 moto

View File

@ -3,7 +3,7 @@ import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL) # logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto' __title__ = 'moto'
__version__ = '1.3.9' __version__ = '1.3.11'
from .acm import mock_acm # flake8: noqa from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa

View File

@ -231,6 +231,10 @@ class LambdaFunction(BaseModel):
config.update({"VpcId": "vpc-123abc"}) config.update({"VpcId": "vpc-123abc"})
return config return config
@property
def physical_resource_id(self):
return self.function_name
def __repr__(self): def __repr__(self):
return json.dumps(self.get_configuration()) return json.dumps(self.get_configuration())

File diff suppressed because it is too large Load Diff

View File

@ -6,13 +6,16 @@ import decimal
import json import json
import re import re
import uuid import uuid
import six
import boto3 import boto3
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from .comparisons import get_comparison_func, get_filter_expression, Op from .comparisons import get_comparison_func
from .comparisons import get_filter_expression
from .comparisons import get_expected
from .exceptions import InvalidIndexNameError from .exceptions import InvalidIndexNameError
@ -68,10 +71,34 @@ class DynamoType(object):
except ValueError: except ValueError:
return float(self.value) return float(self.value)
elif self.is_set(): elif self.is_set():
return set(self.value) sub_type = self.type[0]
return set([DynamoType({sub_type: v}).cast_value
for v in self.value])
elif self.is_list():
return [DynamoType(v).cast_value for v in self.value]
elif self.is_map():
return dict([
(k, DynamoType(v).cast_value)
for k, v in self.value.items()])
else: else:
return self.value return self.value
def child_attr(self, key):
"""
Get Map or List children by key. str for Map, int for List.
Returns DynamoType or None.
"""
if isinstance(key, six.string_types) and self.is_map() and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, int) and self.is_list():
idx = key
if idx >= 0 and idx < len(self.value):
return DynamoType(self.value[idx])
return None
def to_json(self): def to_json(self):
return {self.type: self.value} return {self.type: self.value}
@ -89,6 +116,12 @@ class DynamoType(object):
def is_set(self): def is_set(self):
return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' return self.type == 'SS' or self.type == 'NS' or self.type == 'BS'
def is_list(self):
return self.type == 'L'
def is_map(self):
return self.type == 'M'
def same_type(self, other): def same_type(self, other):
return self.type == other.type return self.type == other.type
@ -504,7 +537,9 @@ class Table(BaseModel):
keys.append(range_key) keys.append(range_key)
return keys return keys
def put_item(self, item_attrs, expected=None, overwrite=False): def put_item(self, item_attrs, expected=None, condition_expression=None,
expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
if self.has_range_key: if self.has_range_key:
range_value = DynamoType(item_attrs.get(self.range_key_attr)) range_value = DynamoType(item_attrs.get(self.range_key_attr))
@ -527,29 +562,15 @@ class Table(BaseModel):
self.range_key_type, item_attrs) self.range_key_type, item_attrs)
if not overwrite: if not overwrite:
if current is None: if not get_expected(expected).expr(current):
current_attr = {} raise ValueError('The conditional request failed')
elif hasattr(current, 'attrs'): condition_op = get_filter_expression(
current_attr = current.attrs condition_expression,
else: expression_attribute_names,
current_attr = current expression_attribute_values)
if not condition_op.expr(current):
raise ValueError('The conditional request failed')
for key, val in expected.items():
if 'Exists' in val and val['Exists'] is False \
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
if key in current_attr:
raise ValueError("The conditional request failed")
elif key not in current_attr:
raise ValueError("The conditional request failed")
elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value:
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
dynamo_types = [
DynamoType(ele) for ele in
val.get("AttributeValueList", [])
]
if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types):
raise ValueError('The conditional request failed')
if range_value: if range_value:
self.items[hash_value][range_value] = item self.items[hash_value][range_value] = item
else: else:
@ -902,11 +923,15 @@ class DynamoDBBackend(BaseBackend):
table.global_indexes = list(gsis_by_name.values()) table.global_indexes = list(gsis_by_name.values())
return table return table
def put_item(self, table_name, item_attrs, expected=None, overwrite=False): def put_item(self, table_name, item_attrs, expected=None,
condition_expression=None, expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
table = self.tables.get(table_name) table = self.tables.get(table_name)
if not table: if not table:
return None return None
return table.put_item(item_attrs, expected, overwrite) return table.put_item(item_attrs, expected, condition_expression,
expression_attribute_names,
expression_attribute_values, overwrite)
def get_table_keys_name(self, table_name, keys): def get_table_keys_name(self, table_name, keys):
""" """
@ -962,10 +987,7 @@ class DynamoDBBackend(BaseBackend):
range_values = [DynamoType(range_value) range_values = [DynamoType(range_value)
for range_value in range_value_dicts] for range_value in range_value_dicts]
if filter_expression is not None: filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
return table.query(hash_key, range_comparison, range_values, limit, return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs)
@ -980,17 +1002,14 @@ class DynamoDBBackend(BaseBackend):
dynamo_types = [DynamoType(value) for value in comparison_values] dynamo_types = [DynamoType(value) for value in comparison_values]
scan_filters[key] = (comparison_operator, dynamo_types) scan_filters[key] = (comparison_operator, dynamo_types)
if filter_expression is not None: filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')]) projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')])
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression) return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression)
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names, def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected=None): expression_attribute_values, expected=None, condition_expression=None):
table = self.get_table(table_name) table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]): if all([table.hash_key_attr in key, table.range_key_attr in key]):
@ -1009,32 +1028,17 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value) item = table.get_item(hash_value, range_value)
if item is None:
item_attr = {}
elif hasattr(item, 'attrs'):
item_attr = item.attrs
else:
item_attr = item
if not expected: if not expected:
expected = {} expected = {}
for key, val in expected.items(): if not get_expected(expected).expr(item):
if 'Exists' in val and val['Exists'] is False \ raise ValueError('The conditional request failed')
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': condition_op = get_filter_expression(
if key in item_attr: condition_expression,
raise ValueError("The conditional request failed") expression_attribute_names,
elif key not in item_attr: expression_attribute_values)
raise ValueError("The conditional request failed") if not condition_op.expr(item):
elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value: raise ValueError('The conditional request failed')
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
dynamo_types = [
DynamoType(ele) for ele in
val.get("AttributeValueList", [])
]
if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types):
raise ValueError('The conditional request failed')
# Update does not fail on new items, so create one # Update does not fail on new items, so create one
if item is None: if item is None:

View File

@ -32,67 +32,6 @@ def get_empty_str_error():
)) ))
def condition_expression_to_expected(condition_expression, expression_attribute_names, expression_attribute_values):
"""
Limited condition expression syntax parsing.
Supports Global Negation ex: NOT(inner expressions).
Supports simple AND conditions ex: cond_a AND cond_b and cond_c.
Atomic expressions supported are attribute_exists(key), attribute_not_exists(key) and #key = :value.
"""
expected = {}
if condition_expression and 'OR' not in condition_expression:
reverse_re = re.compile('^NOT\s*\((.*)\)$')
reverse_m = reverse_re.match(condition_expression.strip())
reverse = False
if reverse_m:
reverse = True
condition_expression = reverse_m.group(1)
cond_items = [c.strip() for c in condition_expression.split('AND')]
if cond_items:
exists_re = re.compile('^attribute_exists\s*\((.*)\)$')
not_exists_re = re.compile(
'^attribute_not_exists\s*\((.*)\)$')
equals_re = re.compile('^(#?\w+)\s*=\s*(\:?\w+)')
for cond in cond_items:
exists_m = exists_re.match(cond)
not_exists_m = not_exists_re.match(cond)
equals_m = equals_re.match(cond)
if exists_m:
attribute_name = expression_attribute_names_lookup(exists_m.group(1), expression_attribute_names)
expected[attribute_name] = {'Exists': True if not reverse else False}
elif not_exists_m:
attribute_name = expression_attribute_names_lookup(not_exists_m.group(1), expression_attribute_names)
expected[attribute_name] = {'Exists': False if not reverse else True}
elif equals_m:
attribute_name = expression_attribute_names_lookup(equals_m.group(1), expression_attribute_names)
attribute_value = expression_attribute_values_lookup(equals_m.group(2), expression_attribute_values)
expected[attribute_name] = {
'AttributeValueList': [attribute_value],
'ComparisonOperator': 'EQ' if not reverse else 'NEQ'}
return expected
def expression_attribute_names_lookup(attribute_name, expression_attribute_names):
if attribute_name.startswith('#') and attribute_name in expression_attribute_names:
return expression_attribute_names[attribute_name]
else:
return attribute_name
def expression_attribute_values_lookup(attribute_value, expression_attribute_values):
if isinstance(attribute_value, six.string_types) and \
attribute_value.startswith(':') and\
attribute_value in expression_attribute_values:
return expression_attribute_values[attribute_value]
else:
return attribute_value
class DynamoHandler(BaseResponse): class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers): def get_endpoint_name(self, headers):
@ -288,18 +227,18 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected # Attempt to parse simple ConditionExpressions into an Expected
# expression # expression
if not expected: condition_expression = self.body.get('ConditionExpression')
condition_expression = self.body.get('ConditionExpression') expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression, if condition_expression:
expression_attribute_names, overwrite = False
expression_attribute_values)
if expected:
overwrite = False
try: try:
result = self.dynamodb_backend.put_item(name, item, expected, overwrite) result = self.dynamodb_backend.put_item(
name, item, expected, condition_expression,
expression_attribute_names, expression_attribute_values,
overwrite)
except ValueError: except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er, 'A condition specified in the operation could not be evaluated.') return self.error(er, 'A condition specified in the operation could not be evaluated.')
@ -626,7 +565,7 @@ class DynamoHandler(BaseResponse):
name = self.body['TableName'] name = self.body['TableName']
key = self.body['Key'] key = self.body['Key']
return_values = self.body.get('ReturnValues', 'NONE') return_values = self.body.get('ReturnValues', 'NONE')
update_expression = self.body.get('UpdateExpression') update_expression = self.body.get('UpdateExpression', '').strip()
attribute_updates = self.body.get('AttributeUpdates') attribute_updates = self.body.get('AttributeUpdates')
expression_attribute_names = self.body.get( expression_attribute_names = self.body.get(
'ExpressionAttributeNames', {}) 'ExpressionAttributeNames', {})
@ -653,13 +592,9 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected # Attempt to parse simple ConditionExpressions into an Expected
# expression # expression
if not expected: condition_expression = self.body.get('ConditionExpression')
condition_expression = self.body.get('ConditionExpression') expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression,
expression_attribute_names,
expression_attribute_values)
# Support spaces between operators in an update expression # Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c` # E.g. `a = b + c` -> `a=b+c`
@ -670,7 +605,7 @@ class DynamoHandler(BaseResponse):
try: try:
item = self.dynamodb_backend.update_item( item = self.dynamodb_backend.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names, name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected expression_attribute_values, expected, condition_expression
) )
except ValueError: except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'

View File

@ -60,6 +60,15 @@ class DBParameterGroupNotFoundError(RDSClientError):
'DB Parameter Group {0} not found.'.format(db_parameter_group_name)) 'DB Parameter Group {0} not found.'.format(db_parameter_group_name))
class OptionGroupNotFoundFaultError(RDSClientError):
def __init__(self, option_group_name):
super(OptionGroupNotFoundFaultError, self).__init__(
'OptionGroupNotFoundFault',
'Specified OptionGroupName: {0} not found.'.format(option_group_name)
)
class InvalidDBClusterStateFaultError(RDSClientError): class InvalidDBClusterStateFaultError(RDSClientError):
def __init__(self, database_identifier): def __init__(self, database_identifier):

View File

@ -20,6 +20,7 @@ from .exceptions import (RDSClientError,
DBSecurityGroupNotFoundError, DBSecurityGroupNotFoundError,
DBSubnetGroupNotFoundError, DBSubnetGroupNotFoundError,
DBParameterGroupNotFoundError, DBParameterGroupNotFoundError,
OptionGroupNotFoundFaultError,
InvalidDBClusterStateFaultError, InvalidDBClusterStateFaultError,
InvalidDBInstanceStateError, InvalidDBInstanceStateError,
SnapshotQuotaExceededError, SnapshotQuotaExceededError,
@ -70,6 +71,7 @@ class Database(BaseModel):
self.port = Database.default_port(self.engine) self.port = Database.default_port(self.engine)
self.db_instance_identifier = kwargs.get('db_instance_identifier') self.db_instance_identifier = kwargs.get('db_instance_identifier')
self.db_name = kwargs.get("db_name") self.db_name = kwargs.get("db_name")
self.instance_create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
self.publicly_accessible = kwargs.get("publicly_accessible") self.publicly_accessible = kwargs.get("publicly_accessible")
if self.publicly_accessible is None: if self.publicly_accessible is None:
self.publicly_accessible = True self.publicly_accessible = True
@ -99,6 +101,8 @@ class Database(BaseModel):
'preferred_backup_window', '13:14-13:44') 'preferred_backup_window', '13:14-13:44')
self.license_model = kwargs.get('license_model', 'general-public-license') self.license_model = kwargs.get('license_model', 'general-public-license')
self.option_group_name = kwargs.get('option_group_name', None) self.option_group_name = kwargs.get('option_group_name', None)
if self.option_group_name and self.option_group_name not in rds2_backends[self.region].option_groups:
raise OptionGroupNotFoundFaultError(self.option_group_name)
self.default_option_groups = {"MySQL": "default.mysql5.6", self.default_option_groups = {"MySQL": "default.mysql5.6",
"mysql": "default.mysql5.6", "mysql": "default.mysql5.6",
"postgres": "default.postgres9.3" "postgres": "default.postgres9.3"
@ -148,6 +152,7 @@ class Database(BaseModel):
<VpcSecurityGroups/> <VpcSecurityGroups/>
<DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier> <DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier>
<DbiResourceId>{{ database.dbi_resource_id }}</DbiResourceId> <DbiResourceId>{{ database.dbi_resource_id }}</DbiResourceId>
<InstanceCreateTime>{{ database.instance_create_time }}</InstanceCreateTime>
<PreferredBackupWindow>03:50-04:20</PreferredBackupWindow> <PreferredBackupWindow>03:50-04:20</PreferredBackupWindow>
<PreferredMaintenanceWindow>wed:06:38-wed:07:08</PreferredMaintenanceWindow> <PreferredMaintenanceWindow>wed:06:38-wed:07:08</PreferredMaintenanceWindow>
<ReadReplicaDBInstanceIdentifiers> <ReadReplicaDBInstanceIdentifiers>
@ -173,6 +178,10 @@ class Database(BaseModel):
<LicenseModel>{{ database.license_model }}</LicenseModel> <LicenseModel>{{ database.license_model }}</LicenseModel>
<EngineVersion>{{ database.engine_version }}</EngineVersion> <EngineVersion>{{ database.engine_version }}</EngineVersion>
<OptionGroupMemberships> <OptionGroupMemberships>
<OptionGroupMembership>
<OptionGroupName>{{ database.option_group_name }}</OptionGroupName>
<Status>in-sync</Status>
</OptionGroupMembership>
</OptionGroupMemberships> </OptionGroupMemberships>
<DBParameterGroups> <DBParameterGroups>
{% for db_parameter_group in database.db_parameter_groups() %} {% for db_parameter_group in database.db_parameter_groups() %}
@ -373,7 +382,7 @@ class Database(BaseModel):
"Address": "{{ database.address }}", "Address": "{{ database.address }}",
"Port": "{{ database.port }}" "Port": "{{ database.port }}"
}, },
"InstanceCreateTime": null, "InstanceCreateTime": "{{ database.instance_create_time }}",
"Iops": null, "Iops": null,
"ReadReplicaDBInstanceIdentifiers": [{%- for replica in database.replicas -%} "ReadReplicaDBInstanceIdentifiers": [{%- for replica in database.replicas -%}
{%- if not loop.first -%},{%- endif -%} {%- if not loop.first -%},{%- endif -%}
@ -873,13 +882,16 @@ class RDS2Backend(BaseBackend):
def create_option_group(self, option_group_kwargs): def create_option_group(self, option_group_kwargs):
option_group_id = option_group_kwargs['name'] option_group_id = option_group_kwargs['name']
valid_option_group_engines = {'mysql': ['5.6'], valid_option_group_engines = {'mariadb': ['10.0', '10.1', '10.2', '10.3'],
'oracle-se1': ['11.2'], 'mysql': ['5.5', '5.6', '5.7', '8.0'],
'oracle-se': ['11.2'], 'oracle-se2': ['11.2', '12.1', '12.2'],
'oracle-ee': ['11.2'], 'oracle-se1': ['11.2', '12.1', '12.2'],
'oracle-se': ['11.2', '12.1', '12.2'],
'oracle-ee': ['11.2', '12.1', '12.2'],
'sqlserver-se': ['10.50', '11.00'], 'sqlserver-se': ['10.50', '11.00'],
'sqlserver-ee': ['10.50', '11.00'] 'sqlserver-ee': ['10.50', '11.00'],
} 'sqlserver-ex': ['10.50', '11.00'],
'sqlserver-web': ['10.50', '11.00']}
if option_group_kwargs['name'] in self.option_groups: if option_group_kwargs['name'] in self.option_groups:
raise RDSClientError('OptionGroupAlreadyExistsFault', raise RDSClientError('OptionGroupAlreadyExistsFault',
'An option group named {0} already exists.'.format(option_group_kwargs['name'])) 'An option group named {0} already exists.'.format(option_group_kwargs['name']))
@ -905,8 +917,7 @@ class RDS2Backend(BaseBackend):
if option_group_name in self.option_groups: if option_group_name in self.option_groups:
return self.option_groups.pop(option_group_name) return self.option_groups.pop(option_group_name)
else: else:
raise RDSClientError( raise OptionGroupNotFoundFaultError(option_group_name)
'OptionGroupNotFoundFault', 'Specified OptionGroupName: {0} not found.'.format(option_group_name))
def describe_option_groups(self, option_group_kwargs): def describe_option_groups(self, option_group_kwargs):
option_group_list = [] option_group_list = []
@ -935,8 +946,7 @@ class RDS2Backend(BaseBackend):
else: else:
option_group_list.append(option_group) option_group_list.append(option_group)
if not len(option_group_list): if not len(option_group_list):
raise RDSClientError('OptionGroupNotFoundFault', raise OptionGroupNotFoundFaultError(option_group_kwargs['name'])
'Specified OptionGroupName: {0} not found.'.format(option_group_kwargs['name']))
return option_group_list[marker:max_records + marker] return option_group_list[marker:max_records + marker]
@staticmethod @staticmethod
@ -965,8 +975,7 @@ class RDS2Backend(BaseBackend):
def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None): def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None):
if option_group_name not in self.option_groups: if option_group_name not in self.option_groups:
raise RDSClientError('OptionGroupNotFoundFault', raise OptionGroupNotFoundFaultError(option_group_name)
'Specified OptionGroupName: {0} not found.'.format(option_group_name))
if not options_to_include and not options_to_remove: if not options_to_include and not options_to_remove:
raise RDSClientError('InvalidParameterValue', raise RDSClientError('InvalidParameterValue',
'At least one option must be added, modified, or removed.') 'At least one option must be added, modified, or removed.')

View File

@ -34,7 +34,7 @@ class RDS2Response(BaseResponse):
"master_user_password": self._get_param('MasterUserPassword'), "master_user_password": self._get_param('MasterUserPassword'),
"master_username": self._get_param('MasterUsername'), "master_username": self._get_param('MasterUsername'),
"multi_az": self._get_bool_param("MultiAZ"), "multi_az": self._get_bool_param("MultiAZ"),
# OptionGroupName "option_group_name": self._get_param("OptionGroupName"),
"port": self._get_param('Port'), "port": self._get_param('Port'),
# PreferredBackupWindow # PreferredBackupWindow
# PreferredMaintenanceWindow # PreferredMaintenanceWindow

View File

@ -85,6 +85,7 @@ class RecordSet(BaseModel):
self.health_check = kwargs.get('HealthCheckId') self.health_check = kwargs.get('HealthCheckId')
self.hosted_zone_name = kwargs.get('HostedZoneName') self.hosted_zone_name = kwargs.get('HostedZoneName')
self.hosted_zone_id = kwargs.get('HostedZoneId') self.hosted_zone_id = kwargs.get('HostedZoneId')
self.alias_target = kwargs.get('AliasTarget')
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -143,6 +144,13 @@ class RecordSet(BaseModel):
{% if record_set.ttl %} {% if record_set.ttl %}
<TTL>{{ record_set.ttl }}</TTL> <TTL>{{ record_set.ttl }}</TTL>
{% endif %} {% endif %}
{% if record_set.alias_target %}
<AliasTarget>
<HostedZoneId>{{ record_set.alias_target['HostedZoneId'] }}</HostedZoneId>
<DNSName>{{ record_set.alias_target['DNSName'] }}</DNSName>
<EvaluateTargetHealth>{{ record_set.alias_target['EvaluateTargetHealth'] }}</EvaluateTargetHealth>
</AliasTarget>
{% else %}
<ResourceRecords> <ResourceRecords>
{% for record in record_set.records %} {% for record in record_set.records %}
<ResourceRecord> <ResourceRecord>
@ -150,6 +158,7 @@ class RecordSet(BaseModel):
</ResourceRecord> </ResourceRecord>
{% endfor %} {% endfor %}
</ResourceRecords> </ResourceRecords>
{% endif %}
{% if record_set.health_check %} {% if record_set.health_check %}
<HealthCheckId>{{ record_set.health_check }}</HealthCheckId> <HealthCheckId>{{ record_set.health_check }}</HealthCheckId>
{% endif %} {% endif %}

View File

@ -134,10 +134,7 @@ class Route53(BaseResponse):
# Depending on how many records there are, this may # Depending on how many records there are, this may
# or may not be a list # or may not be a list
resource_records = [resource_records] resource_records = [resource_records]
record_values = [x['Value'] for x in resource_records] record_set['ResourceRecords'] = [x['Value'] for x in resource_records]
elif 'AliasTarget' in record_set:
record_values = [record_set['AliasTarget']['DNSName']]
record_set['ResourceRecords'] = record_values
if action == 'CREATE': if action == 'CREATE':
the_zone.add_rrset(record_set) the_zone.add_rrset(record_set)
else: else:

View File

@ -807,7 +807,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
body = b'' body = b''
if method == 'GET': if method == 'GET':
return self._key_response_get(bucket_name, query, key_name, headers) return self._key_response_get(bucket_name, query, key_name, headers=request.headers)
elif method == 'PUT': elif method == 'PUT':
return self._key_response_put(request, body, bucket_name, query, key_name, headers) return self._key_response_put(request, body, bucket_name, query, key_name, headers)
elif method == 'HEAD': elif method == 'HEAD':
@ -842,10 +842,15 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
parts=parts parts=parts
) )
version_id = query.get('versionId', [None])[0] version_id = query.get('versionId', [None])[0]
if_modified_since = headers.get('If-Modified-Since', None)
key = self.backend.get_key( key = self.backend.get_key(
bucket_name, key_name, version_id=version_id) bucket_name, key_name, version_id=version_id)
if key is None: if key is None:
raise MissingKey(key_name) raise MissingKey(key_name)
if if_modified_since:
if_modified_since = str_to_rfc_1123_datetime(if_modified_since)
if if_modified_since and key.last_modified < if_modified_since:
return 304, response_headers, 'Not Modified'
if 'acl' in query: if 'acl' in query:
template = self.response_template(S3_OBJECT_ACL_RESPONSE) template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return 200, response_headers, template.render(obj=key) return 200, response_headers, template.render(obj=key)

81
moto/ses/feedback.py Normal file
View File

@ -0,0 +1,81 @@
"""
SES Feedback messages
Extracted from https://docs.aws.amazon.com/ses/latest/DeveloperGuide/notification-contents.html
"""
COMMON_MAIL = {
"notificationType": "Bounce, Complaint, or Delivery.",
"mail": {
"timestamp": "2018-10-08T14:05:45 +0000",
"messageId": "000001378603177f-7a5433e7-8edb-42ae-af10-f0181f34d6ee-000000",
"source": "sender@example.com",
"sourceArn": "arn:aws:ses:us-west-2:888888888888:identity/example.com",
"sourceIp": "127.0.3.0",
"sendingAccountId": "123456789012",
"destination": [
"recipient@example.com"
],
"headersTruncated": False,
"headers": [
{
"name": "From",
"value": "\"Sender Name\" <sender@example.com>"
},
{
"name": "To",
"value": "\"Recipient Name\" <recipient@example.com>"
}
],
"commonHeaders": {
"from": [
"Sender Name <sender@example.com>"
],
"date": "Mon, 08 Oct 2018 14:05:45 +0000",
"to": [
"Recipient Name <recipient@example.com>"
],
"messageId": " custom-message-ID",
"subject": "Message sent using Amazon SES"
}
}
}
BOUNCE = {
"bounceType": "Permanent",
"bounceSubType": "General",
"bouncedRecipients": [
{
"status": "5.0.0",
"action": "failed",
"diagnosticCode": "smtp; 550 user unknown",
"emailAddress": "recipient1@example.com"
},
{
"status": "4.0.0",
"action": "delayed",
"emailAddress": "recipient2@example.com"
}
],
"reportingMTA": "example.com",
"timestamp": "2012-05-25T14:59:38.605Z",
"feedbackId": "000001378603176d-5a4b5ad9-6f30-4198-a8c3-b1eb0c270a1d-000000",
"remoteMtaIp": "127.0.2.0"
}
COMPLAINT = {
"userAgent": "AnyCompany Feedback Loop (V0.01)",
"complainedRecipients": [
{
"emailAddress": "recipient1@example.com"
}
],
"complaintFeedbackType": "abuse",
"arrivalDate": "2009-12-03T04:24:21.000-05:00",
"timestamp": "2012-05-25T14:59:38.623Z",
"feedbackId": "000001378603177f-18c07c78-fa81-4a58-9dd1-fedc3cb8f49a-000000"
}
DELIVERY = {
"timestamp": "2014-05-28T22:41:01.184Z",
"processingTimeMillis": 546,
"recipients": ["success@simulator.amazonses.com"],
"smtpResponse": "250 ok: Message 64111812 accepted",
"reportingMTA": "a8-70.smtp-out.amazonses.com",
"remoteMtaIp": "127.0.2.0"
}

View File

@ -4,13 +4,41 @@ import email
from email.utils import parseaddr from email.utils import parseaddr
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.sns.models import sns_backends
from .exceptions import MessageRejectedError from .exceptions import MessageRejectedError
from .utils import get_random_message_id from .utils import get_random_message_id
from .feedback import COMMON_MAIL, BOUNCE, COMPLAINT, DELIVERY
RECIPIENT_LIMIT = 50 RECIPIENT_LIMIT = 50
class SESFeedback(BaseModel):
BOUNCE = "Bounce"
COMPLAINT = "Complaint"
DELIVERY = "Delivery"
SUCCESS_ADDR = "success"
BOUNCE_ADDR = "bounce"
COMPLAINT_ADDR = "complaint"
FEEDBACK_SUCCESS_MSG = {"test": "success"}
FEEDBACK_BOUNCE_MSG = {"test": "bounce"}
FEEDBACK_COMPLAINT_MSG = {"test": "complaint"}
@staticmethod
def generate_message(msg_type):
msg = dict(COMMON_MAIL)
if msg_type == SESFeedback.BOUNCE:
msg["bounce"] = BOUNCE
elif msg_type == SESFeedback.COMPLAINT:
msg["complaint"] = COMPLAINT
elif msg_type == SESFeedback.DELIVERY:
msg["delivery"] = DELIVERY
return msg
class Message(BaseModel): class Message(BaseModel):
def __init__(self, message_id, source, subject, body, destinations): def __init__(self, message_id, source, subject, body, destinations):
@ -48,6 +76,7 @@ class SESBackend(BaseBackend):
self.domains = [] self.domains = []
self.sent_messages = [] self.sent_messages = []
self.sent_message_count = 0 self.sent_message_count = 0
self.sns_topics = {}
def _is_verified_address(self, source): def _is_verified_address(self, source):
_, address = parseaddr(source) _, address = parseaddr(source)
@ -77,7 +106,7 @@ class SESBackend(BaseBackend):
else: else:
self.domains.remove(identity) self.domains.remove(identity)
def send_email(self, source, subject, body, destinations): def send_email(self, source, subject, body, destinations, region):
recipient_count = sum(map(len, destinations.values())) recipient_count = sum(map(len, destinations.values()))
if recipient_count > RECIPIENT_LIMIT: if recipient_count > RECIPIENT_LIMIT:
raise MessageRejectedError('Too many recipients.') raise MessageRejectedError('Too many recipients.')
@ -86,13 +115,46 @@ class SESBackend(BaseBackend):
"Email address not verified %s" % source "Email address not verified %s" % source
) )
self.__process_sns_feedback__(source, destinations, region)
message_id = get_random_message_id() message_id = get_random_message_id()
message = Message(message_id, source, subject, body, destinations) message = Message(message_id, source, subject, body, destinations)
self.sent_messages.append(message) self.sent_messages.append(message)
self.sent_message_count += recipient_count self.sent_message_count += recipient_count
return message return message
def send_raw_email(self, source, destinations, raw_data): def __type_of_message__(self, destinations):
"""Checks the destination for any special address that could indicate delivery, complaint or bounce
like in SES simualtor"""
alladdress = destinations.get("ToAddresses", []) + destinations.get("CcAddresses", []) + destinations.get("BccAddresses", [])
for addr in alladdress:
if SESFeedback.SUCCESS_ADDR in addr:
return SESFeedback.DELIVERY
elif SESFeedback.COMPLAINT_ADDR in addr:
return SESFeedback.COMPLAINT
elif SESFeedback.BOUNCE_ADDR in addr:
return SESFeedback.BOUNCE
return None
def __generate_feedback__(self, msg_type):
"""Generates the SNS message for the feedback"""
return SESFeedback.generate_message(msg_type)
def __process_sns_feedback__(self, source, destinations, region):
domain = str(source)
if "@" in domain:
domain = domain.split("@")[1]
if domain in self.sns_topics:
msg_type = self.__type_of_message__(destinations)
if msg_type is not None:
sns_topic = self.sns_topics[domain].get(msg_type, None)
if sns_topic is not None:
message = self.__generate_feedback__(msg_type)
if message:
sns_backends[region].publish(sns_topic, message)
def send_raw_email(self, source, destinations, raw_data, region):
if source is not None: if source is not None:
_, source_email_address = parseaddr(source) _, source_email_address = parseaddr(source)
if source_email_address not in self.addresses: if source_email_address not in self.addresses:
@ -122,6 +184,8 @@ class SESBackend(BaseBackend):
if recipient_count > RECIPIENT_LIMIT: if recipient_count > RECIPIENT_LIMIT:
raise MessageRejectedError('Too many recipients.') raise MessageRejectedError('Too many recipients.')
self.__process_sns_feedback__(source, destinations, region)
self.sent_message_count += recipient_count self.sent_message_count += recipient_count
message_id = get_random_message_id() message_id = get_random_message_id()
message = RawMessage(message_id, source, destinations, raw_data) message = RawMessage(message_id, source, destinations, raw_data)
@ -131,5 +195,16 @@ class SESBackend(BaseBackend):
def get_send_quota(self): def get_send_quota(self):
return SESQuota(self.sent_message_count) return SESQuota(self.sent_message_count)
def set_identity_notification_topic(self, identity, notification_type, sns_topic):
identity_sns_topics = self.sns_topics.get(identity, {})
if sns_topic is None:
del identity_sns_topics[notification_type]
else:
identity_sns_topics[notification_type] = sns_topic
self.sns_topics[identity] = identity_sns_topics
return {}
ses_backend = SESBackend() ses_backend = SESBackend()

View File

@ -70,7 +70,7 @@ class EmailResponse(BaseResponse):
break break
destinations[dest_type].append(address[0]) destinations[dest_type].append(address[0])
message = ses_backend.send_email(source, subject, body, destinations) message = ses_backend.send_email(source, subject, body, destinations, self.region)
template = self.response_template(SEND_EMAIL_RESPONSE) template = self.response_template(SEND_EMAIL_RESPONSE)
return template.render(message=message) return template.render(message=message)
@ -92,7 +92,7 @@ class EmailResponse(BaseResponse):
break break
destinations.append(address[0]) destinations.append(address[0])
message = ses_backend.send_raw_email(source, destinations, raw_data) message = ses_backend.send_raw_email(source, destinations, raw_data, self.region)
template = self.response_template(SEND_RAW_EMAIL_RESPONSE) template = self.response_template(SEND_RAW_EMAIL_RESPONSE)
return template.render(message=message) return template.render(message=message)
@ -101,6 +101,18 @@ class EmailResponse(BaseResponse):
template = self.response_template(GET_SEND_QUOTA_RESPONSE) template = self.response_template(GET_SEND_QUOTA_RESPONSE)
return template.render(quota=quota) return template.render(quota=quota)
def set_identity_notification_topic(self):
identity = self.querystring.get("Identity")[0]
not_type = self.querystring.get("NotificationType")[0]
sns_topic = self.querystring.get("SnsTopic")
if sns_topic:
sns_topic = sns_topic[0]
ses_backend.set_identity_notification_topic(identity, not_type, sns_topic)
template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE)
return template.render()
VERIFY_EMAIL_IDENTITY = """<VerifyEmailIdentityResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/"> VERIFY_EMAIL_IDENTITY = """<VerifyEmailIdentityResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<VerifyEmailIdentityResult/> <VerifyEmailIdentityResult/>
@ -200,3 +212,10 @@ GET_SEND_QUOTA_RESPONSE = """<GetSendQuotaResponse xmlns="http://ses.amazonaws.c
<RequestId>273021c6-c866-11e0-b926-699e21c3af9e</RequestId> <RequestId>273021c6-c866-11e0-b926-699e21c3af9e</RequestId>
</ResponseMetadata> </ResponseMetadata>
</GetSendQuotaResponse>""" </GetSendQuotaResponse>"""
SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE = """<SetIdentityNotificationTopicResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<SetIdentityNotificationTopicResult/>
<ResponseMetadata>
<RequestId>47e0ef1a-9bf2-11e1-9279-0100e8cf109a</RequestId>
</ResponseMetadata>
</SetIdentityNotificationTopicResponse>"""

View File

@ -838,44 +838,47 @@ def test_filter_expression():
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
# NOT test 2 # NOT test 2
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': 8}}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': '8'}})
filter_expr.expr(row1).should.be(False) # Id = 8 so should be false filter_expr.expr(row1).should.be(False) # Id = 8 so should be false
# AND test # AND test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 7}}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '7'}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
filter_expr.expr(row2).should.be(False) filter_expr.expr(row2).should.be(False)
# OR test # OR test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': 5}, ':v1': {'N': 8}}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '8'}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
# BETWEEN test # BETWEEN test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 10}}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '10'}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
# PAREN test # PAREN test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': 8}, ':v1': {'N': 5}}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': '8'}, ':v1': {'N': '5'}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
# IN test # IN test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN :v0', {}, {':v0': {'NS': [7, 8, 9]}}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN (:v0, :v1, :v2)', {}, {
':v0': {'N': '7'},
':v1': {'N': '8'},
':v2': {'N': '9'}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
# attribute function tests (with extra spaces) # attribute function tests (with extra spaces)
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, N)', {}, {}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, :v0)', {}, {':v0': {'S': 'N'}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
# beginswith function test # beginswith function test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, Some)', {}, {}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, :v0)', {}, {':v0': {'S': 'Some'}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
filter_expr.expr(row2).should.be(False) filter_expr.expr(row2).should.be(False)
# contains function test # contains function test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, test1)', {}, {}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, :v0)', {}, {':v0': {'S': 'test1'}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
filter_expr.expr(row2).should.be(False) filter_expr.expr(row2).should.be(False)
@ -916,14 +919,26 @@ def test_query_filter():
TableName='test1', TableName='test1',
Item={ Item={
'client': {'S': 'client1'}, 'client': {'S': 'client1'},
'app': {'S': 'app1'} 'app': {'S': 'app1'},
'nested': {'M': {
'version': {'S': 'version1'},
'contents': {'L': [
{'S': 'value1'}, {'S': 'value2'},
]},
}},
} }
) )
client.put_item( client.put_item(
TableName='test1', TableName='test1',
Item={ Item={
'client': {'S': 'client1'}, 'client': {'S': 'client1'},
'app': {'S': 'app2'} 'app': {'S': 'app2'},
'nested': {'M': {
'version': {'S': 'version2'},
'contents': {'L': [
{'S': 'value1'}, {'S': 'value2'},
]},
}},
} }
) )
@ -945,6 +960,18 @@ def test_query_filter():
) )
assert response['Count'] == 2 assert response['Count'] == 2
response = table.query(
KeyConditionExpression=Key('client').eq('client1'),
FilterExpression=Attr('nested.version').contains('version')
)
assert response['Count'] == 2
response = table.query(
KeyConditionExpression=Key('client').eq('client1'),
FilterExpression=Attr('nested.contents[0]').eq('value1')
)
assert response['Count'] == 2
@mock_dynamodb2 @mock_dynamodb2
def test_scan_filter(): def test_scan_filter():
@ -1698,7 +1725,6 @@ def test_dynamodb_streams_2():
@mock_dynamodb2 @mock_dynamodb2
def test_condition_expressions(): def test_condition_expressions():
client = boto3.client('dynamodb', region_name='us-east-1') client = boto3.client('dynamodb', region_name='us-east-1')
dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
# Create the DynamoDB table. # Create the DynamoDB table.
client.create_table( client.create_table(
@ -1751,6 +1777,57 @@ def test_condition_expressions():
} }
) )
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
'app': {'S': 'app1'},
'match': {'S': 'match'},
'existing': {'S': 'existing'},
},
ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)',
ExpressionAttributeNames={
'#nonexistent': 'nope',
'#existing': 'existing'
}
)
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
'app': {'S': 'app1'},
'match': {'S': 'match'},
'existing': {'S': 'existing'},
},
ConditionExpression='#client BETWEEN :a AND :z',
ExpressionAttributeNames={
'#client': 'client',
},
ExpressionAttributeValues={
':a': {'S': 'a'},
':z': {'S': 'z'},
}
)
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
'app': {'S': 'app1'},
'match': {'S': 'match'},
'existing': {'S': 'existing'},
},
ConditionExpression='#client IN (:client1, :client2)',
ExpressionAttributeNames={
'#client': 'client',
},
ExpressionAttributeValues={
':client1': {'S': 'client1'},
':client2': {'S': 'client2'},
}
)
with assert_raises(client.exceptions.ConditionalCheckFailedException): with assert_raises(client.exceptions.ConditionalCheckFailedException):
client.put_item( client.put_item(
TableName='test1', TableName='test1',
@ -1803,6 +1880,89 @@ def test_condition_expressions():
} }
) )
# Make sure update_item honors ConditionExpression as well
client.update_item(
TableName='test1',
Key={
'client': {'S': 'client1'},
'app': {'S': 'app1'},
},
UpdateExpression='set #match=:match',
ConditionExpression='attribute_exists(#existing)',
ExpressionAttributeNames={
'#existing': 'existing',
'#match': 'match',
},
ExpressionAttributeValues={
':match': {'S': 'match'}
}
)
with assert_raises(client.exceptions.ConditionalCheckFailedException):
client.update_item(
TableName='test1',
Key={
'client': { 'S': 'client1'},
'app': { 'S': 'app1'},
},
UpdateExpression='set #match=:match',
ConditionExpression='attribute_not_exists(#existing)',
ExpressionAttributeValues={
':match': {'S': 'match'}
},
ExpressionAttributeNames={
'#existing': 'existing',
'#match': 'match',
},
)
@mock_dynamodb2
def test_condition_expression__attr_doesnt_exist():
client = boto3.client('dynamodb', region_name='us-east-1')
client.create_table(
TableName='test',
KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}],
AttributeDefinitions=[
{'AttributeName': 'forum_name', 'AttributeType': 'S'},
],
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1},
)
client.put_item(
TableName='test',
Item={
'forum_name': {'S': 'foo'},
'ttl': {'N': 'bar'},
}
)
def update_if_attr_doesnt_exist():
# Test nonexistent top-level attribute.
client.update_item(
TableName='test',
Key={
'forum_name': {'S': 'the-key'},
'subject': {'S': 'the-subject'},
},
UpdateExpression='set #new_state=:new_state, #ttl=:ttl',
ConditionExpression='attribute_not_exists(#new_state)',
ExpressionAttributeNames={'#new_state': 'foobar', '#ttl': 'ttl'},
ExpressionAttributeValues={
':new_state': {'S': 'some-value'},
':ttl': {'N': '12345.67'},
},
ReturnValues='ALL_NEW',
)
update_if_attr_doesnt_exist()
# Second time should fail
with assert_raises(client.exceptions.ConditionalCheckFailedException):
update_if_attr_doesnt_exist()
@mock_dynamodb2 @mock_dynamodb2
def test_query_gsi_with_range_key(): def test_query_gsi_with_range_key():

View File

@ -34,6 +34,39 @@ def test_create_database():
db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False) db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False)
db_instance['DbiResourceId'].should.contain("db-") db_instance['DbiResourceId'].should.contain("db-")
db_instance['CopyTagsToSnapshot'].should.equal(False) db_instance['CopyTagsToSnapshot'].should.equal(False)
db_instance['InstanceCreateTime'].should.be.a("datetime.datetime")
@mock_rds2
def test_create_database_non_existing_option_group():
conn = boto3.client('rds', region_name='us-west-2')
database = conn.create_db_instance.when.called_with(
DBInstanceIdentifier='db-master-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
OptionGroupName='non-existing').should.throw(ClientError)
@mock_rds2
def test_create_database_with_option_group():
conn = boto3.client('rds', region_name='us-west-2')
conn.create_option_group(OptionGroupName='my-og',
EngineName='mysql',
MajorEngineVersion='5.6',
OptionGroupDescription='test option group')
database = conn.create_db_instance(DBInstanceIdentifier='db-master-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
OptionGroupName='my-og')
db_instance = database['DBInstance']
db_instance['AllocatedStorage'].should.equal(10)
db_instance['DBInstanceClass'].should.equal('db.m1.small')
db_instance['DBName'].should.equal('staging-postgres')
db_instance['OptionGroupMemberships'][0]['OptionGroupName'].should.equal('my-og')
@mock_rds2 @mock_rds2
@ -204,6 +237,7 @@ def test_get_databases_paginated():
resp3 = conn.describe_db_instances(MaxRecords=100) resp3 = conn.describe_db_instances(MaxRecords=100)
resp3["DBInstances"].should.have.length_of(51) resp3["DBInstances"].should.have.length_of(51)
@mock_rds2 @mock_rds2
def test_describe_non_existant_database(): def test_describe_non_existant_database():
conn = boto3.client('rds', region_name='us-west-2') conn = boto3.client('rds', region_name='us-west-2')

View File

@ -173,14 +173,16 @@ def test_alias_rrset():
changes.commit() changes.commit()
rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets = conn.get_all_rrsets(zoneid, type="A")
rrset_records = [(rr_set.name, rr) for rr_set in rrsets for rr in rr_set.resource_records] alias_targets = [rr_set.alias_dns_name for rr_set in rrsets]
rrset_records.should.have.length_of(2) alias_targets.should.have.length_of(2)
rrset_records.should.contain(('foo.alias.testdns.aws.com.', 'foo.testdns.aws.com')) alias_targets.should.contain('foo.testdns.aws.com')
rrset_records.should.contain(('bar.alias.testdns.aws.com.', 'bar.testdns.aws.com')) alias_targets.should.contain('bar.testdns.aws.com')
rrsets[0].resource_records[0].should.equal('foo.testdns.aws.com') rrsets[0].alias_dns_name.should.equal('foo.testdns.aws.com')
rrsets[0].resource_records.should.have.length_of(0)
rrsets = conn.get_all_rrsets(zoneid, type="CNAME") rrsets = conn.get_all_rrsets(zoneid, type="CNAME")
rrsets.should.have.length_of(1) rrsets.should.have.length_of(1)
rrsets[0].resource_records[0].should.equal('bar.testdns.aws.com') rrsets[0].alias_dns_name.should.equal('bar.testdns.aws.com')
rrsets[0].resource_records.should.have.length_of(0)
@mock_route53_deprecated @mock_route53_deprecated
@ -583,6 +585,39 @@ def test_change_resource_record_sets_crud_valid():
cname_record_detail['TTL'].should.equal(60) cname_record_detail['TTL'].should.equal(60)
cname_record_detail['ResourceRecords'].should.equal([{'Value': '192.168.1.1'}]) cname_record_detail['ResourceRecords'].should.equal([{'Value': '192.168.1.1'}])
# Update to add Alias.
cname_alias_record_endpoint_payload = {
'Comment': 'Update to Alias prod.redis.db',
'Changes': [
{
'Action': 'UPSERT',
'ResourceRecordSet': {
'Name': 'prod.redis.db.',
'Type': 'A',
'TTL': 60,
'AliasTarget': {
'HostedZoneId': hosted_zone_id,
'DNSName': 'prod.redis.alias.',
'EvaluateTargetHealth': False,
}
}
}
]
}
conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=cname_alias_record_endpoint_payload)
response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id)
cname_alias_record_detail = response['ResourceRecordSets'][0]
cname_alias_record_detail['Name'].should.equal('prod.redis.db.')
cname_alias_record_detail['Type'].should.equal('A')
cname_alias_record_detail['TTL'].should.equal(60)
cname_alias_record_detail['AliasTarget'].should.equal({
'HostedZoneId': hosted_zone_id,
'DNSName': 'prod.redis.alias.',
'EvaluateTargetHealth': False,
})
cname_alias_record_detail.should_not.contain('ResourceRecords')
# Delete record with wrong type. # Delete record with wrong type.
delete_payload = { delete_payload = {
'Comment': 'delete prod.redis.db', 'Comment': 'delete prod.redis.db',

View File

@ -1596,6 +1596,28 @@ def test_boto3_delete_versioned_bucket():
client.delete_bucket(Bucket='blah') client.delete_bucket(Bucket='blah')
@mock_s3
def test_boto3_get_object_if_modified_since():
s3 = boto3.client('s3', region_name='us-east-1')
bucket_name = "blah"
s3.create_bucket(Bucket=bucket_name)
key = 'hello.txt'
s3.put_object(
Bucket=bucket_name,
Key=key,
Body='test'
)
with assert_raises(botocore.exceptions.ClientError) as err:
s3.get_object(
Bucket=bucket_name,
Key=key,
IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1)
)
e = err.exception
e.response['Error'].should.equal({'Code': '304', 'Message': 'Not Modified'})
@mock_s3 @mock_s3
def test_boto3_head_object_if_modified_since(): def test_boto3_head_object_if_modified_since():

View File

@ -0,0 +1,114 @@
from __future__ import unicode_literals
import boto3
import json
from botocore.exceptions import ClientError
from six.moves.email_mime_multipart import MIMEMultipart
from six.moves.email_mime_text import MIMEText
import sure # noqa
from nose import tools
from moto import mock_ses, mock_sns, mock_sqs
from moto.ses.models import SESFeedback
@mock_ses
def test_enable_disable_ses_sns_communication():
conn = boto3.client('ses', region_name='us-east-1')
conn.set_identity_notification_topic(
Identity='test.com',
NotificationType='Bounce',
SnsTopic='the-arn'
)
conn.set_identity_notification_topic(
Identity='test.com',
NotificationType='Bounce'
)
def __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region, expected_msg):
"""Setup the AWS environment to test the SES SNS Feedback"""
# Environment setup
# Create SQS queue
sqs_conn.create_queue(QueueName=queue)
# Create SNS topic
create_topic_response = sns_conn.create_topic(Name=topic)
topic_arn = create_topic_response["TopicArn"]
# Subscribe the SNS topic to the SQS queue
sns_conn.subscribe(TopicArn=topic_arn,
Protocol="sqs",
Endpoint="arn:aws:sqs:%s:123456789012:%s" % (region, queue))
# Verify SES domain
ses_conn.verify_domain_identity(Domain=domain)
# Setup SES notification topic
if expected_msg is not None:
ses_conn.set_identity_notification_topic(
Identity=domain,
NotificationType=expected_msg,
SnsTopic=topic_arn
)
def __test_sns_feedback__(addr, expected_msg):
region_name = "us-east-1"
ses_conn = boto3.client('ses', region_name=region_name)
sns_conn = boto3.client('sns', region_name=region_name)
sqs_conn = boto3.resource('sqs', region_name=region_name)
domain = "example.com"
topic = "bounce-arn-feedback"
queue = "feedback-test-queue"
__setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region_name, expected_msg)
# Send the message
kwargs = dict(
Source="test@" + domain,
Destination={
"ToAddresses": [addr + "@" + domain],
"CcAddresses": ["test_cc@" + domain],
"BccAddresses": ["test_bcc@" + domain],
},
Message={
"Subject": {"Data": "test subject"},
"Body": {"Text": {"Data": "test body"}}
}
)
ses_conn.send_email(**kwargs)
# Wait for messages in the queues
queue = sqs_conn.get_queue_by_name(QueueName=queue)
messages = queue.receive_messages(MaxNumberOfMessages=1)
if expected_msg is not None:
msg = messages[0].body
msg = json.loads(msg)
assert msg["Message"] == SESFeedback.generate_message(expected_msg)
else:
assert len(messages) == 0
@mock_sqs
@mock_sns
@mock_ses
def test_no_sns_feedback():
__test_sns_feedback__("test", None)
@mock_sqs
@mock_sns
@mock_ses
def test_sns_feedback_bounce():
__test_sns_feedback__(SESFeedback.BOUNCE_ADDR, SESFeedback.BOUNCE)
@mock_sqs
@mock_sns
@mock_ses
def test_sns_feedback_complaint():
__test_sns_feedback__(SESFeedback.COMPLAINT_ADDR, SESFeedback.COMPLAINT)
@mock_sqs
@mock_sns
@mock_ses
def test_sns_feedback_delivery():
__test_sns_feedback__(SESFeedback.SUCCESS_ADDR, SESFeedback.DELIVERY)