diff --git a/moto/rds2/exceptions.py b/moto/rds2/exceptions.py index b6dc5bb99..0232f11c7 100644 --- a/moto/rds2/exceptions.py +++ b/moto/rds2/exceptions.py @@ -29,10 +29,10 @@ class DBInstanceNotFoundError(RDSClientError): class DBSnapshotNotFoundError(RDSClientError): - def __init__(self): + def __init__(self, snapshot_identifier): super(DBSnapshotNotFoundError, self).__init__( "DBSnapshotNotFound", - "DBSnapshotIdentifier does not refer to an existing DB snapshot.", + "DBSnapshot {} not found.".format(snapshot_identifier), ) @@ -107,3 +107,15 @@ class DBSnapshotAlreadyExistsError(RDSClientError): database_snapshot_identifier ), ) + + +class InvalidParameterValue(RDSClientError): + def __init__(self, message): + super(InvalidParameterValue, self).__init__("InvalidParameterValue", message) + + +class InvalidParameterCombination(RDSClientError): + def __init__(self, message): + super(InvalidParameterCombination, self).__init__( + "InvalidParameterCombination", message + ) diff --git a/moto/rds2/models.py b/moto/rds2/models.py index eb4159025..4651f68d4 100644 --- a/moto/rds2/models.py +++ b/moto/rds2/models.py @@ -25,10 +25,24 @@ from .exceptions import ( InvalidDBInstanceStateError, SnapshotQuotaExceededError, DBSnapshotAlreadyExistsError, + InvalidParameterValue, + InvalidParameterCombination, ) +from .utils import FilterDef, apply_filter, merge_filters, validate_filters class Database(CloudFormationModel): + + SUPPORTED_FILTERS = { + "db-cluster-id": FilterDef(None, "DB Cluster Identifiers"), + "db-instance-id": FilterDef( + ["db_instance_arn", "db_instance_identifier"], "DB Instance Identifiers" + ), + "dbi-resource-id": FilterDef(["dbi_resource_id"], "Dbi Resource Ids"), + "domain": FilterDef(None, ""), + "engine": FilterDef(["engine"], "Engine Names"), + } + def __init__(self, **kwargs): self.status = "available" self.is_replica = False @@ -517,6 +531,18 @@ class Database(CloudFormationModel): class Snapshot(BaseModel): + + SUPPORTED_FILTERS = { + "db-instance-id": FilterDef( + ["database.db_instance_arn", "database.db_instance_identifier"], + "DB Instance Identifiers", + ), + "db-snapshot-id": FilterDef(["snapshot_id"], "DB Snapshot Identifiers"), + "dbi-resource-id": FilterDef(["database.dbi_resource_id"], "Dbi Resource Ids"), + "snapshot-type": FilterDef(None, "Snapshot Types"), + "engine": FilterDef(["database.engine"], "Engine Names"), + } + def __init__(self, database, snapshot_id, tags): self.database = database self.snapshot_id = snapshot_id @@ -534,6 +560,7 @@ class Snapshot(BaseModel): """ {{ snapshot.snapshot_id }} {{ database.db_instance_identifier }} + {{ database.dbi_resource_id }} {{ snapshot.created_at }} {{ database.engine }} {{ database.allocated_storage }} @@ -839,7 +866,7 @@ class RDS2Backend(BaseBackend): def delete_snapshot(self, db_snapshot_identifier): if db_snapshot_identifier not in self.snapshots: - raise DBSnapshotNotFoundError() + raise DBSnapshotNotFoundError(db_snapshot_identifier) return self.snapshots.pop(db_snapshot_identifier) @@ -858,28 +885,35 @@ class RDS2Backend(BaseBackend): primary.add_replica(replica) return replica - def describe_databases(self, db_instance_identifier=None): + def describe_databases(self, db_instance_identifier=None, filters=None): + databases = self.databases if db_instance_identifier: - if db_instance_identifier in self.databases: - return [self.databases[db_instance_identifier]] - else: - raise DBInstanceNotFoundError(db_instance_identifier) - return self.databases.values() + filters = merge_filters( + filters, {"db-instance-id": [db_instance_identifier]} + ) + if filters: + databases = self._filter_resources(databases, filters, Database) + if db_instance_identifier and not databases: + raise DBInstanceNotFoundError(db_instance_identifier) + return list(databases.values()) - def describe_snapshots(self, db_instance_identifier, db_snapshot_identifier): + def describe_snapshots( + self, db_instance_identifier, db_snapshot_identifier, filters=None + ): + snapshots = self.snapshots if db_instance_identifier: - db_instance_snapshots = [] - for snapshot in self.snapshots.values(): - if snapshot.database.db_instance_identifier == db_instance_identifier: - db_instance_snapshots.append(snapshot) - return db_instance_snapshots - + filters = merge_filters( + filters, {"db-instance-id": [db_instance_identifier]} + ) if db_snapshot_identifier: - if db_snapshot_identifier in self.snapshots: - return [self.snapshots[db_snapshot_identifier]] - raise DBSnapshotNotFoundError() - - return self.snapshots.values() + filters = merge_filters( + filters, {"db-snapshot-id": [db_snapshot_identifier]} + ) + if filters: + snapshots = self._filter_resources(snapshots, filters, Snapshot) + if db_snapshot_identifier and not snapshots and not db_instance_identifier: + raise DBSnapshotNotFoundError(db_snapshot_identifier) + return list(snapshots.values()) def modify_database(self, db_instance_identifier, db_kwargs): database = self.describe_databases(db_instance_identifier)[0] @@ -1322,6 +1356,18 @@ class RDS2Backend(BaseBackend): "InvalidParameterValue", "Invalid resource name: {0}".format(arn) ) + @staticmethod + def _filter_resources(resources, filters, resource_class): + try: + filter_defs = resource_class.SUPPORTED_FILTERS + validate_filters(filters, filter_defs) + return apply_filter(resources, filters, filter_defs) + except KeyError as e: + # https://stackoverflow.com/questions/24998968/why-does-strkeyerror-add-extra-quotes + raise InvalidParameterValue(e.args[0]) + except ValueError as e: + raise InvalidParameterCombination(str(e)) + class OptionGroup(object): def __init__(self, name, engine_name, major_engine_version, description=None): diff --git a/moto/rds2/responses.py b/moto/rds2/responses.py index b63e9f8b8..32ad65846 100644 --- a/moto/rds2/responses.py +++ b/moto/rds2/responses.py @@ -5,6 +5,7 @@ from moto.core.responses import BaseResponse from moto.ec2.models import ec2_backends from .models import rds2_backends from .exceptions import DBParameterGroupNotFoundError +from .utils import filters_from_querystring class RDS2Response(BaseResponse): @@ -122,7 +123,10 @@ class RDS2Response(BaseResponse): def describe_db_instances(self): db_instance_identifier = self._get_param("DBInstanceIdentifier") - all_instances = list(self.backend.describe_databases(db_instance_identifier)) + filters = filters_from_querystring(self.querystring) + all_instances = list( + self.backend.describe_databases(db_instance_identifier, filters=filters) + ) marker = self._get_param("Marker") all_ids = [instance.db_instance_identifier for instance in all_instances] if marker: @@ -178,8 +182,9 @@ class RDS2Response(BaseResponse): def describe_db_snapshots(self): db_instance_identifier = self._get_param("DBInstanceIdentifier") db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") + filters = filters_from_querystring(self.querystring) snapshots = self.backend.describe_snapshots( - db_instance_identifier, db_snapshot_identifier + db_instance_identifier, db_snapshot_identifier, filters ) template = self.response_template(DESCRIBE_SNAPSHOTS_TEMPLATE) return template.render(snapshots=snapshots) diff --git a/moto/rds2/utils.py b/moto/rds2/utils.py new file mode 100644 index 000000000..594a2daab --- /dev/null +++ b/moto/rds2/utils.py @@ -0,0 +1,169 @@ +from __future__ import unicode_literals + +import re +from collections import namedtuple + +from botocore.utils import merge_dicts + +from moto.compat import OrderedDict + +FilterDef = namedtuple( + "FilterDef", + [ + # A list of object attributes to check against the filter values. + # Set to None if filter is not yet implemented in `moto`. + "attrs_to_check", + # Description of the filter, e.g. 'Object Identifiers'. + # Used in filter error messaging. + "description", + ], +) + + +def filters_from_querystring(querystring): + """Parses filters out of the query string computed by the + moto.core.responses.BaseResponse class. + + :param dict[str, list[str]] querystring: + The `moto`-processed URL query string dictionary. + :returns: + Dict mapping filter names to filter values. + :rtype: + dict[str, list[str]] + """ + response_values = {} + for key, value in sorted(querystring.items()): + match = re.search(r"Filters.Filter.(\d).Name", key) + if match: + filter_index = match.groups()[0] + value_prefix = "Filters.Filter.{0}.Value".format(filter_index) + filter_values = [ + filter_value[0] + for filter_key, filter_value in querystring.items() + if filter_key.startswith(value_prefix) + ] + # The AWS query protocol serializes empty lists as an empty string. + if filter_values == [""]: + filter_values = [] + response_values[value[0]] = filter_values + return response_values + + +def get_object_value(obj, attr): + """Retrieves an arbitrary attribute value from an object. + + Nested attributes can be specified using dot notation, + e.g. 'parent.child'. + + :param object obj: + A valid Python object. + :param str attr: + The attribute name of the value to retrieve from the object. + :returns: + The attribute value, if it exists, or None. + :rtype: + any + """ + keys = attr.split(".") + val = obj + for key in keys: + if hasattr(val, key): + val = getattr(val, key) + else: + return None + return val + + +def merge_filters(filters_to_update, filters_to_merge): + """Given two groups of filters, merge the second into the first. + + List values are appended instead of overwritten: + + >>> merge_filters({'filter-name': ['value1']}, {'filter-name':['value2']}) + >>> {'filter-name': ['value1', 'value2']} + + :param filters_to_update: + The filters to update. + :type filters_to_update: + dict[str, list] or None + :param filters_to_merge: + The filters to merge. + :type filters_to_merge: + dict[str, list] or None + :returns: + The updated filters. + :rtype: + dict[str, list] + """ + if filters_to_update is None: + filters_to_update = {} + if filters_to_merge is None: + filters_to_merge = {} + merge_dicts(filters_to_update, filters_to_merge, append_lists=True) + return filters_to_update + + +def validate_filters(filters, filter_defs): + """Validates filters against a set of filter definitions. + + Raises standard Python exceptions which should be caught + and translated to an appropriate AWS/Moto exception higher + up the call stack. + + :param dict[str, list] filters: + The filters to validate. + :param dict[str, FilterDef] filter_defs: + The filter definitions to validate against. + :returns: None + :rtype: None + :raises KeyError: + if filter name not found in the filter definitions. + :raises ValueError: + if filter values is an empty list. + :raises NotImplementedError: + if `moto` does not yet support this filter. + """ + for filter_name, filter_values in filters.items(): + filter_def = filter_defs.get(filter_name) + if filter_def is None: + raise KeyError("Unrecognized filter name: {}".format(filter_name)) + if not filter_values: + raise ValueError( + "The list of {} must not be empty.".format(filter_def.description) + ) + if filter_def.attrs_to_check is None: + raise NotImplementedError( + "{} filter has not been implemented in Moto yet.".format(filter_name) + ) + + +def apply_filter(resources, filters, filter_defs): + """Apply an arbitrary filter to a group of resources. + + :param dict[str, object] resources: + A dictionary mapping resource identifiers to resource objects. + :param dict[str, list] filters: + The filters to apply. + :param dict[str, FilterDef] filter_defs: + The supported filter definitions for the resource type. + :returns: + The filtered collection of resources. + :rtype: + dict[str, object] + """ + resources_filtered = OrderedDict() + for identifier, obj in resources.items(): + matches_filter = False + for filter_name, filter_values in filters.items(): + filter_def = filter_defs.get(filter_name) + for attr in filter_def.attrs_to_check: + if get_object_value(obj, attr) in filter_values: + matches_filter = True + break + else: + matches_filter = False + if not matches_filter: + break + if matches_filter: + resources_filtered[identifier] = obj + return resources_filtered diff --git a/tests/test_rds2/test_filters.py b/tests/test_rds2/test_filters.py new file mode 100644 index 000000000..a59867d24 --- /dev/null +++ b/tests/test_rds2/test_filters.py @@ -0,0 +1,359 @@ +from __future__ import unicode_literals + +import boto3 +import pytest +import sure # noqa +from botocore.exceptions import ClientError + +from moto import mock_rds2 + + +class TestDBInstanceFilters(object): + + mock_rds = mock_rds2() + + @classmethod + def setup_class(cls): + cls.mock_rds.start() + client = boto3.client("rds", region_name="us-west-2") + for i in range(10): + identifier = "db-instance-{}".format(i) + engine = "postgres" if (i % 3) else "mysql" + client.create_db_instance( + DBInstanceIdentifier=identifier, + Engine=engine, + DBInstanceClass="db.m1.small", + ) + cls.client = client + + @classmethod + def teardown_class(cls): + cls.mock_rds.stop() + + def test_invalid_filter_name_raises_error(self): + with pytest.raises(ClientError) as ex: + self.client.describe_db_instances( + Filters=[{"Name": "invalid-filter-name", "Values": []}] + ) + ex.value.response["Error"]["Code"].should.equal("InvalidParameterValue") + ex.value.response["Error"]["Message"].should.equal( + "Unrecognized filter name: invalid-filter-name" + ) + + def test_empty_filter_values_raises_error(self): + with pytest.raises(ClientError) as ex: + self.client.describe_db_instances( + Filters=[{"Name": "db-instance-id", "Values": []}] + ) + ex.value.response["Error"]["Code"].should.equal("InvalidParameterCombination") + ex.value.response["Error"]["Message"].should.contain("must not be empty") + + def test_db_instance_id_filter(self): + resp = self.client.describe_db_instances() + db_instance_identifier = resp["DBInstances"][0]["DBInstanceIdentifier"] + + db_instances = self.client.describe_db_instances( + Filters=[{"Name": "db-instance-id", "Values": [db_instance_identifier]}] + ).get("DBInstances") + db_instances.should.have.length_of(1) + db_instances[0]["DBInstanceIdentifier"].should.equal(db_instance_identifier) + + def test_db_instance_id_filter_works_with_arns(self): + resp = self.client.describe_db_instances() + db_instance_arn = resp["DBInstances"][0]["DBInstanceArn"] + + db_instances = self.client.describe_db_instances( + Filters=[{"Name": "db-instance-id", "Values": [db_instance_arn]}] + ).get("DBInstances") + db_instances.should.have.length_of(1) + db_instances[0]["DBInstanceArn"].should.equal(db_instance_arn) + + def test_dbi_resource_id_filter(self): + resp = self.client.describe_db_instances() + dbi_resource_identifier = resp["DBInstances"][0]["DbiResourceId"] + + db_instances = self.client.describe_db_instances( + Filters=[{"Name": "dbi-resource-id", "Values": [dbi_resource_identifier]}] + ).get("DBInstances") + for db_instance in db_instances: + db_instance["DbiResourceId"].should.equal(dbi_resource_identifier) + + def test_engine_filter(self): + db_instances = self.client.describe_db_instances( + Filters=[{"Name": "engine", "Values": ["postgres"]}] + ).get("DBInstances") + for db_instance in db_instances: + db_instance["Engine"].should.equal("postgres") + + db_instances = self.client.describe_db_instances( + Filters=[{"Name": "engine", "Values": ["oracle"]}] + ).get("DBInstances") + db_instances.should.have.length_of(0) + + def test_multiple_filters(self): + resp = self.client.describe_db_instances( + Filters=[ + { + "Name": "db-instance-id", + "Values": ["db-instance-0", "db-instance-1", "db-instance-3"], + }, + {"Name": "engine", "Values": ["mysql", "oracle"]}, + ] + ) + returned_identifiers = [ + db["DBInstanceIdentifier"] for db in resp["DBInstances"] + ] + returned_identifiers.should.have.length_of(2) + "db-instance-0".should.be.within(returned_identifiers) + "db-instance-3".should.be.within(returned_identifiers) + + def test_invalid_db_instance_identifier_with_exclusive_filter(self): + # Passing a non-existent DBInstanceIdentifier will not raise an error + # if the resulting filter matches other resources. + resp = self.client.describe_db_instances( + DBInstanceIdentifier="non-existent", + Filters=[{"Name": "db-instance-id", "Values": ["db-instance-1"]}], + ) + resp["DBInstances"].should.have.length_of(1) + resp["DBInstances"][0]["DBInstanceIdentifier"].should.equal("db-instance-1") + + def test_invalid_db_instance_identifier_with_non_matching_filter(self): + # Passing a non-existent DBInstanceIdentifier will raise an error if + # the resulting filter does not match any resources. + with pytest.raises(ClientError) as ex: + self.client.describe_db_instances( + DBInstanceIdentifier="non-existent", + Filters=[{"Name": "engine", "Values": ["mysql"]}], + ) + ex.value.response["Error"]["Code"].should.equal("DBInstanceNotFound") + ex.value.response["Error"]["Message"].should.equal( + "Database non-existent not found." + ) + + def test_valid_db_instance_identifier_with_exclusive_filter(self): + # Passing a valid DBInstanceIdentifier with a filter it does not match + # but does match other resources will return those other resources. + resp = self.client.describe_db_instances( + DBInstanceIdentifier="db-instance-0", + Filters=[ + {"Name": "db-instance-id", "Values": ["db-instance-1"]}, + {"Name": "engine", "Values": ["postgres"]}, + ], + ) + returned_identifiers = [ + db["DBInstanceIdentifier"] for db in resp["DBInstances"] + ] + "db-instance-0".should_not.be.within(returned_identifiers) + "db-instance-1".should.be.within(returned_identifiers) + + def test_valid_db_instance_identifier_with_inclusive_filter(self): + # Passing a valid DBInstanceIdentifier with a filter it matches but also + # matches other resources will return all matching resources. + resp = self.client.describe_db_instances( + DBInstanceIdentifier="db-instance-0", + Filters=[ + {"Name": "db-instance-id", "Values": ["db-instance-1"]}, + {"Name": "engine", "Values": ["mysql", "postgres"]}, + ], + ) + returned_identifiers = [ + db["DBInstanceIdentifier"] for db in resp["DBInstances"] + ] + "db-instance-0".should.be.within(returned_identifiers) + "db-instance-1".should.be.within(returned_identifiers) + + def test_valid_db_instance_identifier_with_non_matching_filter(self): + # Passing a valid DBInstanceIdentifier will raise an error if the + # resulting filter does not match any resources. + with pytest.raises(ClientError) as ex: + self.client.describe_db_instances( + DBInstanceIdentifier="db-instance-0", + Filters=[{"Name": "engine", "Values": ["postgres"]}], + ) + ex.value.response["Error"]["Code"].should.equal("DBInstanceNotFound") + ex.value.response["Error"]["Message"].should.equal( + "Database db-instance-0 not found." + ) + + +class TestDBSnapshotFilters(object): + + mock_rds = mock_rds2() + + @classmethod + def setup_class(cls): + cls.mock_rds.start() + client = boto3.client("rds", region_name="us-west-2") + # We'll set up two instances (one postgres, one mysql) + # with two snapshots each. + for i in range(2): + identifier = "db-instance-{}".format(i) + engine = "postgres" if i else "mysql" + client.create_db_instance( + DBInstanceIdentifier=identifier, + Engine=engine, + DBInstanceClass="db.m1.small", + ) + for j in range(2): + client.create_db_snapshot( + DBInstanceIdentifier=identifier, + DBSnapshotIdentifier="{}-snapshot-{}".format(identifier, j), + ) + cls.client = client + + @classmethod + def teardown_class(cls): + cls.mock_rds.stop() + + def test_invalid_filter_name_raises_error(self): + with pytest.raises(ClientError) as ex: + self.client.describe_db_snapshots( + Filters=[{"Name": "invalid-filter-name", "Values": []}] + ) + ex.value.response["Error"]["Code"].should.equal("InvalidParameterValue") + ex.value.response["Error"]["Message"].should.equal( + "Unrecognized filter name: invalid-filter-name" + ) + + def test_empty_filter_values_raises_error(self): + with pytest.raises(ClientError) as ex: + self.client.describe_db_snapshots( + Filters=[{"Name": "db-snapshot-id", "Values": []}] + ) + ex.value.response["Error"]["Code"].should.equal("InvalidParameterCombination") + ex.value.response["Error"]["Message"].should.contain("must not be empty") + + def test_db_snapshot_id_filter(self): + snapshots = self.client.describe_db_snapshots( + Filters=[{"Name": "db-snapshot-id", "Values": ["db-instance-1-snapshot-0"]}] + ).get("DBSnapshots") + snapshots.should.have.length_of(1) + snapshots[0]["DBSnapshotIdentifier"].should.equal("db-instance-1-snapshot-0") + + def test_db_instance_id_filter(self): + resp = self.client.describe_db_instances() + db_instance_identifier = resp["DBInstances"][0]["DBInstanceIdentifier"] + + snapshots = self.client.describe_db_snapshots( + Filters=[{"Name": "db-instance-id", "Values": [db_instance_identifier]}] + ).get("DBSnapshots") + for snapshot in snapshots: + snapshot["DBInstanceIdentifier"].should.equal(db_instance_identifier) + + def test_db_instance_id_filter_works_with_arns(self): + resp = self.client.describe_db_instances() + db_instance_identifier = resp["DBInstances"][0]["DBInstanceIdentifier"] + db_instance_arn = resp["DBInstances"][0]["DBInstanceArn"] + + snapshots = self.client.describe_db_snapshots( + Filters=[{"Name": "db-instance-id", "Values": [db_instance_arn]}] + ).get("DBSnapshots") + for snapshot in snapshots: + snapshot["DBInstanceIdentifier"].should.equal(db_instance_identifier) + + def test_dbi_resource_id_filter(self): + resp = self.client.describe_db_instances() + dbi_resource_identifier = resp["DBInstances"][0]["DbiResourceId"] + + snapshots = self.client.describe_db_snapshots( + Filters=[{"Name": "dbi-resource-id", "Values": [dbi_resource_identifier]}] + ).get("DBSnapshots") + for snapshot in snapshots: + snapshot["DbiResourceId"].should.equal(dbi_resource_identifier) + + def test_engine_filter(self): + snapshots = self.client.describe_db_snapshots( + Filters=[{"Name": "engine", "Values": ["postgres"]}] + ).get("DBSnapshots") + for snapshot in snapshots: + snapshot["Engine"].should.equal("postgres") + + snapshots = self.client.describe_db_snapshots( + Filters=[{"Name": "engine", "Values": ["oracle"]}] + ).get("DBSnapshots") + snapshots.should.have.length_of(0) + + def test_multiple_filters(self): + snapshots = self.client.describe_db_snapshots( + Filters=[ + {"Name": "db-snapshot-id", "Values": ["db-instance-0-snapshot-1"]}, + { + "Name": "db-instance-id", + "Values": ["db-instance-1", "db-instance-0"], + }, + {"Name": "engine", "Values": ["mysql"]}, + ] + ).get("DBSnapshots") + snapshots.should.have.length_of(1) + snapshots[0]["DBSnapshotIdentifier"].should.equal("db-instance-0-snapshot-1") + + def test_invalid_snapshot_id_with_db_instance_id_and_filter(self): + # Passing a non-existent DBSnapshotIdentifier will return an empty list + # if DBInstanceIdentifier is also passed in. + resp = self.client.describe_db_snapshots( + DBSnapshotIdentifier="non-existent", + DBInstanceIdentifier="a-db-instance-identifier", + Filters=[{"Name": "db-instance-id", "Values": ["db-instance-1"]}], + ) + resp["DBSnapshots"].should.have.length_of(0) + + def test_invalid_snapshot_id_with_non_matching_filter(self): + # Passing a non-existent DBSnapshotIdentifier will raise an error if + # the resulting filter does not match any resources. + with pytest.raises(ClientError) as ex: + self.client.describe_db_snapshots( + DBSnapshotIdentifier="non-existent", + Filters=[{"Name": "engine", "Values": ["oracle"]}], + ) + ex.value.response["Error"]["Code"].should.equal("DBSnapshotNotFound") + ex.value.response["Error"]["Message"].should.equal( + "DBSnapshot non-existent not found." + ) + + def test_valid_snapshot_id_with_exclusive_filter(self): + # Passing a valid DBSnapshotIdentifier with a filter it does not match + # but does match other resources will return those other resources. + resp = self.client.describe_db_snapshots( + DBSnapshotIdentifier="db-instance-0-snapshot-0", + Filters=[ + {"Name": "db-snapshot-id", "Values": ["db-instance-1-snapshot-1"]}, + {"Name": "db-instance-id", "Values": ["db-instance-1"]}, + {"Name": "engine", "Values": ["postgres"]}, + ], + ) + resp["DBSnapshots"].should.have.length_of(1) + resp["DBSnapshots"][0]["DBSnapshotIdentifier"].should.equal( + "db-instance-1-snapshot-1" + ) + + def test_valid_snapshot_id_with_inclusive_filter(self): + # Passing a valid DBSnapshotIdentifier with a filter it matches but also + # matches other resources will return all matching resources. + snapshots = self.client.describe_db_snapshots( + DBSnapshotIdentifier="db-instance-0-snapshot-0", + Filters=[ + {"Name": "db-snapshot-id", "Values": ["db-instance-1-snapshot-1"]}, + { + "Name": "db-instance-id", + "Values": ["db-instance-1", "db-instance-0"], + }, + {"Name": "engine", "Values": ["mysql", "postgres"]}, + ], + ).get("DBSnapshots") + returned_identifiers = [ss["DBSnapshotIdentifier"] for ss in snapshots] + returned_identifiers.should.have.length_of(2) + "db-instance-0-snapshot-0".should.be.within(returned_identifiers) + "db-instance-1-snapshot-1".should.be.within(returned_identifiers) + + def test_valid_snapshot_id_with_non_matching_filter(self): + # Passing a valid DBSnapshotIdentifier will raise an error if the + # resulting filter does not match any resources. + with pytest.raises(ClientError) as ex: + self.client.describe_db_snapshots( + DBSnapshotIdentifier="db-instance-0-snapshot-0", + Filters=[{"Name": "engine", "Values": ["postgres"]}], + ) + ex.value.response["Error"]["Code"].should.equal("DBSnapshotNotFound") + ex.value.response["Error"]["Message"].should.equal( + "DBSnapshot db-instance-0-snapshot-0 not found." + ) diff --git a/tests/test_rds2/test_utils.py b/tests/test_rds2/test_utils.py new file mode 100644 index 000000000..5cb51a0d5 --- /dev/null +++ b/tests/test_rds2/test_utils.py @@ -0,0 +1,174 @@ +from __future__ import unicode_literals + +import pytest + +from moto.rds2.utils import ( + FilterDef, + apply_filter, + merge_filters, + filters_from_querystring, + validate_filters, +) + + +class TestFilterValidation(object): + @classmethod + def setup_class(cls): + cls.filter_defs = { + "not-implemented": FilterDef(None, ""), + "identifier": FilterDef(["identifier"], "Object Identifiers"), + } + + def test_unrecognized_filter_raises_exception(self): + filters = {"invalid-filter-name": ["test-value"]} + with pytest.raises(KeyError) as ex: + validate_filters(filters, self.filter_defs) + assert "Unrecognized filter name: invalid-filter-name" in str(ex) + + def test_empty_filter_values_raises_exception(self): + filters = {"identifier": []} + with pytest.raises(ValueError) as ex: + validate_filters(filters, self.filter_defs) + assert "Object Identifiers must not be empty" in str(ex) + + def test_unimplemented_filter_raises_exception(self): + filters = {"not-implemented": ["test-value"]} + with pytest.raises(NotImplementedError): + validate_filters(filters, self.filter_defs) + + +class Resource(object): + def __init__(self, identifier, **kwargs): + self.identifier = identifier + self.__dict__.update(kwargs) + + +class TestResourceFiltering(object): + @classmethod + def setup_class(cls): + cls.filter_defs = { + "identifier": FilterDef(["identifier"], "Object Identifiers"), + "nested-resource": FilterDef(["nested.identifier"], "Nested Identifiers"), + "common-attr": FilterDef(["common_attr"], ""), + "multiple-attrs": FilterDef(["common_attr", "uncommon_attr"], ""), + } + cls.resources = { + "identifier-0": Resource("identifier-0"), + "identifier-1": Resource("identifier-1", common_attr="common"), + "identifier-2": Resource("identifier-2"), + "identifier-3": Resource("identifier-3", nested=Resource("nested-id-1")), + "identifier-4": Resource("identifier-4", common_attr="common"), + "identifier-5": Resource("identifier-5", uncommon_attr="common"), + } + + def test_filtering_on_nested_attribute(self): + filters = {"nested-resource": ["nested-id-1"]} + filtered_resources = apply_filter(self.resources, filters, self.filter_defs) + filtered_resources.should.have.length_of(1) + filtered_resources.should.have.key("identifier-3") + + def test_filtering_on_common_attribute(self): + filters = {"common-attr": ["common"]} + filtered_resources = apply_filter(self.resources, filters, self.filter_defs) + filtered_resources.should.have.length_of(2) + filtered_resources.should.have.key("identifier-1") + filtered_resources.should.have.key("identifier-4") + + def test_filtering_on_multiple_attributes(self): + filters = {"multiple-attrs": ["common"]} + filtered_resources = apply_filter(self.resources, filters, self.filter_defs) + filtered_resources.should.have.length_of(3) + filtered_resources.should.have.key("identifier-1") + filtered_resources.should.have.key("identifier-4") + filtered_resources.should.have.key("identifier-5") + + def test_filters_with_multiple_values(self): + filters = {"identifier": ["identifier-0", "identifier-3", "identifier-5"]} + filtered_resources = apply_filter(self.resources, filters, self.filter_defs) + filtered_resources.should.have.length_of(3) + filtered_resources.should.have.key("identifier-0") + filtered_resources.should.have.key("identifier-3") + filtered_resources.should.have.key("identifier-5") + + def test_multiple_filters(self): + filters = { + "identifier": ["identifier-1", "identifier-3", "identifier-5"], + "common-attr": ["common"], + "multiple-attrs": ["common"], + } + filtered_resources = apply_filter(self.resources, filters, self.filter_defs) + filtered_resources.should.have.length_of(1) + filtered_resources.should.have.key("identifier-1") + + +class TestMergingFilters(object): + def test_when_filters_to_update_is_none(self): + filters_to_update = {"filter-name": ["value1"]} + merged = merge_filters(filters_to_update, None) + assert merged == filters_to_update + + def test_when_filters_to_merge_is_none(self): + filters_to_merge = {"filter-name": ["value1"]} + merged = merge_filters(None, filters_to_merge) + assert merged == filters_to_merge + + def test_when_both_filters_are_none(self): + merged = merge_filters(None, None) + assert merged == {} + + def test_values_are_merged(self): + filters_to_update = {"filter-name": ["value1"]} + filters_to_merge = {"filter-name": ["value2"]} + merged = merge_filters(filters_to_update, filters_to_merge) + assert merged == {"filter-name": ["value1", "value2"]} + + def test_complex_merge(self): + filters_to_update = { + "filter-name-1": ["value1"], + "filter-name-2": ["value1", "value2"], + "filter-name-3": ["value1"], + } + filters_to_merge = { + "filter-name-1": ["value2"], + "filter-name-3": ["value2"], + "filter-name-4": ["value1", "value2"], + } + merged = merge_filters(filters_to_update, filters_to_merge) + assert len(merged.keys()) == 4 + for key in merged.keys(): + assert merged[key] == ["value1", "value2"] + + +class TestParsingFiltersFromQuerystring(object): + def test_parse_empty_list(self): + # The AWS query protocol serializes empty lists as an empty string. + querystring = { + "Filters.Filter.1.Name": ["empty-filter"], + "Filters.Filter.1.Value.1": [""], + } + filters = filters_from_querystring(querystring) + assert filters == {"empty-filter": []} + + def test_multiple_values(self): + querystring = { + "Filters.Filter.1.Name": ["multi-value"], + "Filters.Filter.1.Value.1": ["value1"], + "Filters.Filter.1.Value.2": ["value2"], + } + filters = filters_from_querystring(querystring) + values = filters["multi-value"] + assert len(values) == 2 + assert "value1" in values + assert "value2" in values + + def test_multiple_filters(self): + querystring = { + "Filters.Filter.1.Name": ["filter-1"], + "Filters.Filter.1.Value.1": ["value1"], + "Filters.Filter.2.Name": ["filter-2"], + "Filters.Filter.2.Value.1": ["value2"], + } + filters = filters_from_querystring(querystring) + assert len(filters.keys()) == 2 + assert filters["filter-1"] == ["value1"] + assert filters["filter-2"] == ["value2"]