Add support for filters for list_resolver_endpoints (#4598)

This commit is contained in:
kbalk 2021-11-19 08:57:54 -05:00 committed by GitHub
parent 32b2a90ee6
commit be197caba6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 192 additions and 5 deletions

View File

@ -2,6 +2,7 @@
from collections import defaultdict
from datetime import datetime, timezone
from ipaddress import ip_address, ip_network
import re
from boto3 import Session
@ -25,6 +26,8 @@ from moto.route53resolver.validations import validate_args
from moto.utilities.paginator import paginate
from moto.utilities.tagging_service import TaggingService
CAMEL_TO_SNAKE_PATTERN = re.compile(r"(?<!^)(?=[A-Z])")
class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attributes
"""Representation of a fake Route53 Resolver Endpoint."""
@ -32,6 +35,18 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
MAX_TAGS_PER_RESOLVER_ENDPOINT = 200
MAX_ENDPOINTS_PER_REGION = 4
# There are two styles of filter names and either will be transformed
# into lowercase snake.
FILTER_NAMES = [
"creator_request_id",
"direction",
"host_vpc_id",
"ip_address_count",
"name",
"security_group_ids",
"status",
]
def __init__(
self,
region,
@ -328,14 +343,71 @@ class Route53ResolverBackend(BaseBackend):
endpoint = self.resolver_endpoints[resolver_endpoint_id]
return endpoint.ip_descriptions()
@staticmethod
def _add_field_name_to_filter(filters):
"""Convert both styles of filter names to lowercase snake format.
"IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count".
However, "HostVPCId" doesn't fit the pattern, so that's treated
special.
"""
for rr_filter in filters:
filter_name = rr_filter["Name"]
if "_" not in filter_name:
if filter_name == "HostVPCId":
filter_name = "host_vpc_id"
elif filter_name == "HostVpcId":
filter_name = "WRONG"
elif not filter_name.isupper():
filter_name = CAMEL_TO_SNAKE_PATTERN.sub("_", filter_name)
rr_filter["Field"] = filter_name.lower()
@staticmethod
def _validate_filters(filters, allowed_filter_names):
"""Raise exception if filter names are not as expected."""
for rr_filter in filters:
if rr_filter["Field"] not in allowed_filter_names:
raise InvalidParameterException(
f"The filter '{rr_filter['Name']}' is invalid"
)
if "Values" not in rr_filter:
raise InvalidParameterException(
f"No values specified for filter {rr_filter['Name']}"
)
@staticmethod
def _matches_all_filters(entity, filters):
"""Return True if this entity has fields matching all the filters."""
for rr_filter in filters:
field_value = getattr(entity, rr_filter["Field"])
if isinstance(field_value, list):
if not set(field_value).intersection(rr_filter["Values"]):
return False
elif isinstance(field_value, int):
if str(field_value) not in rr_filter["Values"]:
return False
elif field_value not in rr_filter["Values"]:
return False
return True
@paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_endpoints(
self, filters=None, next_token=None, max_results=None,
self, filters, next_token=None, max_results=None,
): # pylint: disable=unused-argument
"""List all resolver endpoints, using filters if specified."""
# TODO - check subsequent filters
# TODO - validate name, values for filters
return sorted(self.resolver_endpoints.values(), key=lambda x: x.name)
if not filters:
filters = []
self._add_field_name_to_filter(filters)
self._validate_filters(filters, ResolverEndpoint.FILTER_NAMES)
endpoints = []
for endpoint in sorted(self.resolver_endpoints.values(), key=lambda x: x.name):
if self._matches_all_filters(endpoint, filters):
endpoints.append(endpoint)
return endpoints
@paginate(pagination_model=PAGINATION_MODEL)
def list_tags_for_resource(

View File

@ -88,7 +88,7 @@ class Route53ResolverResponse(BaseResponse):
endpoints,
next_token,
) = self.route53resolver_backend.list_resolver_endpoints(
filters=filters, next_token=next_token, max_results=max_results
filters, next_token=next_token, max_results=max_results
)
except InvalidToken as exc:
raise InvalidNextTokenException() from exc

View File

@ -670,6 +670,121 @@ def test_route53resolver_list_resolver_endpoints():
assert endpoint["Name"].startswith(f"A{idx + 1}")
@mock_ec2
@mock_route53resolver
def test_route53resolver_list_resolver_endpoints_filters():
"""Test good list_resolver_endpoint API calls that use filters."""
client = boto3.client("route53resolver", region_name=TEST_REGION)
ec2_client = boto3.client("ec2", region_name=TEST_REGION)
random_num = get_random_hex(10)
# Create some endpoints for testing purposes
security_group_id = create_security_group(ec2_client)
vpc_id = create_vpc(ec2_client)
subnet_ids = create_subnets(ec2_client, vpc_id)
ip0_values = ["10.0.1.201", "10.0.1.202", "10.0.1.203", "10.0.1.204"]
ip1_values = ["10.0.0.21", "10.0.0.22", "10.0.0.23", "10.0.0.24"]
endpoints = []
for idx in range(1, 5):
ip_addrs = [
{"SubnetId": subnet_ids[0], "Ip": "10.0.1.200"},
{"SubnetId": subnet_ids[1], "Ip": "10.0.0.20"},
{"SubnetId": subnet_ids[0], "Ip": ip0_values[idx - 1]},
{"SubnetId": subnet_ids[1], "Ip": ip1_values[idx - 1]},
]
response = client.create_resolver_endpoint(
CreatorRequestId=f"F{idx}-{random_num}",
Name=f"F{idx}-{random_num}",
SecurityGroupIds=[security_group_id],
Direction="INBOUND" if idx % 2 else "OUTBOUND",
IpAddresses=ip_addrs,
)
endpoints.append(response["ResolverEndpoint"])
# Try all the valid filter names, including some of the old style names.
response = client.list_resolver_endpoints(
Filters=[{"Name": "CreatorRequestId", "Values": [f"F3-{random_num}"]}]
)
assert len(response["ResolverEndpoints"]) == 1
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F3-{random_num}"
response = client.list_resolver_endpoints(
Filters=[
{
"Name": "CREATOR_REQUEST_ID",
"Values": [f"F2-{random_num}", f"F4-{random_num}"],
}
]
)
assert len(response["ResolverEndpoints"]) == 2
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F2-{random_num}"
assert response["ResolverEndpoints"][1]["CreatorRequestId"] == f"F4-{random_num}"
response = client.list_resolver_endpoints(
Filters=[{"Name": "Direction", "Values": ["INBOUND"]}]
)
assert len(response["ResolverEndpoints"]) == 2
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F1-{random_num}"
assert response["ResolverEndpoints"][1]["CreatorRequestId"] == f"F3-{random_num}"
response = client.list_resolver_endpoints(
Filters=[{"Name": "HostVPCId", "Values": [vpc_id]}]
)
assert len(response["ResolverEndpoints"]) == 4
response = client.list_resolver_endpoints(
Filters=[{"Name": "IpAddressCount", "Values": ["4"]}]
)
assert len(response["ResolverEndpoints"]) == 4
response = client.list_resolver_endpoints(
Filters=[{"Name": "Name", "Values": [f"F1-{random_num}"]}]
)
assert len(response["ResolverEndpoints"]) == 1
assert response["ResolverEndpoints"][0]["Name"] == f"F1-{random_num}"
response = client.list_resolver_endpoints(
Filters=[
{"Name": "HOST_VPC_ID", "Values": [vpc_id]},
{"Name": "DIRECTION", "Values": ["INBOUND"]},
{"Name": "NAME", "Values": [f"F3-{random_num}"]},
]
)
assert len(response["ResolverEndpoints"]) == 1
assert response["ResolverEndpoints"][0]["Name"] == f"F3-{random_num}"
response = client.list_resolver_endpoints(
Filters=[{"Name": "SecurityGroupIds", "Values": [security_group_id]}]
)
assert len(response["ResolverEndpoints"]) == 4
response = client.list_resolver_endpoints(
Filters=[{"Name": "Status", "Values": ["OPERATIONAL"]}]
)
assert len(response["ResolverEndpoints"]) == 4
response = client.list_resolver_endpoints(
Filters=[{"Name": "Status", "Values": ["CREATING"]}]
)
assert len(response["ResolverEndpoints"]) == 0
@mock_route53resolver
def test_route53resolver_bad_list_resolver_endpoints_filters():
"""Test bad list_resolver_endpoint API calls that use filters."""
client = boto3.client("route53resolver", region_name=TEST_REGION)
# botocore barfs on an empty "Values":
# TypeError: list_resolver_endpoints() only accepts keyword arguments.
# client.list_resolver_endpoints([{"Name": "Direction", "Values": []}])
# client.list_resolver_endpoints([{"Values": []}])
with pytest.raises(ClientError) as exc:
client.list_resolver_endpoints(Filters=[{"Name": "foo", "Values": ["bar"]}])
err = exc.value.response["Error"]
assert err["Code"] == "InvalidParameterException"
assert "The filter 'foo' is invalid" in err["Message"]
@mock_ec2
@mock_route53resolver
def test_route53resolver_bad_list_resolver_endpoints():