Add support for filters for list_resolver_endpoints (#4598)
This commit is contained in:
parent
32b2a90ee6
commit
be197caba6
@ -2,6 +2,7 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from ipaddress import ip_address, ip_network
|
||||
import re
|
||||
|
||||
from boto3 import Session
|
||||
|
||||
@ -25,6 +26,8 @@ from moto.route53resolver.validations import validate_args
|
||||
from moto.utilities.paginator import paginate
|
||||
from moto.utilities.tagging_service import TaggingService
|
||||
|
||||
CAMEL_TO_SNAKE_PATTERN = re.compile(r"(?<!^)(?=[A-Z])")
|
||||
|
||||
|
||||
class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attributes
|
||||
"""Representation of a fake Route53 Resolver Endpoint."""
|
||||
@ -32,6 +35,18 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
|
||||
MAX_TAGS_PER_RESOLVER_ENDPOINT = 200
|
||||
MAX_ENDPOINTS_PER_REGION = 4
|
||||
|
||||
# There are two styles of filter names and either will be transformed
|
||||
# into lowercase snake.
|
||||
FILTER_NAMES = [
|
||||
"creator_request_id",
|
||||
"direction",
|
||||
"host_vpc_id",
|
||||
"ip_address_count",
|
||||
"name",
|
||||
"security_group_ids",
|
||||
"status",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
region,
|
||||
@ -328,14 +343,71 @@ class Route53ResolverBackend(BaseBackend):
|
||||
endpoint = self.resolver_endpoints[resolver_endpoint_id]
|
||||
return endpoint.ip_descriptions()
|
||||
|
||||
@staticmethod
|
||||
def _add_field_name_to_filter(filters):
|
||||
"""Convert both styles of filter names to lowercase snake format.
|
||||
|
||||
"IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count".
|
||||
However, "HostVPCId" doesn't fit the pattern, so that's treated
|
||||
special.
|
||||
"""
|
||||
for rr_filter in filters:
|
||||
filter_name = rr_filter["Name"]
|
||||
if "_" not in filter_name:
|
||||
if filter_name == "HostVPCId":
|
||||
filter_name = "host_vpc_id"
|
||||
elif filter_name == "HostVpcId":
|
||||
filter_name = "WRONG"
|
||||
elif not filter_name.isupper():
|
||||
filter_name = CAMEL_TO_SNAKE_PATTERN.sub("_", filter_name)
|
||||
rr_filter["Field"] = filter_name.lower()
|
||||
|
||||
@staticmethod
|
||||
def _validate_filters(filters, allowed_filter_names):
|
||||
"""Raise exception if filter names are not as expected."""
|
||||
for rr_filter in filters:
|
||||
if rr_filter["Field"] not in allowed_filter_names:
|
||||
raise InvalidParameterException(
|
||||
f"The filter '{rr_filter['Name']}' is invalid"
|
||||
)
|
||||
if "Values" not in rr_filter:
|
||||
raise InvalidParameterException(
|
||||
f"No values specified for filter {rr_filter['Name']}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _matches_all_filters(entity, filters):
|
||||
"""Return True if this entity has fields matching all the filters."""
|
||||
for rr_filter in filters:
|
||||
field_value = getattr(entity, rr_filter["Field"])
|
||||
|
||||
if isinstance(field_value, list):
|
||||
if not set(field_value).intersection(rr_filter["Values"]):
|
||||
return False
|
||||
elif isinstance(field_value, int):
|
||||
if str(field_value) not in rr_filter["Values"]:
|
||||
return False
|
||||
elif field_value not in rr_filter["Values"]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_resolver_endpoints(
|
||||
self, filters=None, next_token=None, max_results=None,
|
||||
self, filters, next_token=None, max_results=None,
|
||||
): # pylint: disable=unused-argument
|
||||
"""List all resolver endpoints, using filters if specified."""
|
||||
# TODO - check subsequent filters
|
||||
# TODO - validate name, values for filters
|
||||
return sorted(self.resolver_endpoints.values(), key=lambda x: x.name)
|
||||
if not filters:
|
||||
filters = []
|
||||
|
||||
self._add_field_name_to_filter(filters)
|
||||
self._validate_filters(filters, ResolverEndpoint.FILTER_NAMES)
|
||||
|
||||
endpoints = []
|
||||
for endpoint in sorted(self.resolver_endpoints.values(), key=lambda x: x.name):
|
||||
if self._matches_all_filters(endpoint, filters):
|
||||
endpoints.append(endpoint)
|
||||
return endpoints
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_tags_for_resource(
|
||||
|
@ -88,7 +88,7 @@ class Route53ResolverResponse(BaseResponse):
|
||||
endpoints,
|
||||
next_token,
|
||||
) = self.route53resolver_backend.list_resolver_endpoints(
|
||||
filters=filters, next_token=next_token, max_results=max_results
|
||||
filters, next_token=next_token, max_results=max_results
|
||||
)
|
||||
except InvalidToken as exc:
|
||||
raise InvalidNextTokenException() from exc
|
||||
|
@ -670,6 +670,121 @@ def test_route53resolver_list_resolver_endpoints():
|
||||
assert endpoint["Name"].startswith(f"A{idx + 1}")
|
||||
|
||||
|
||||
@mock_ec2
|
||||
@mock_route53resolver
|
||||
def test_route53resolver_list_resolver_endpoints_filters():
|
||||
"""Test good list_resolver_endpoint API calls that use filters."""
|
||||
client = boto3.client("route53resolver", region_name=TEST_REGION)
|
||||
ec2_client = boto3.client("ec2", region_name=TEST_REGION)
|
||||
random_num = get_random_hex(10)
|
||||
|
||||
# Create some endpoints for testing purposes
|
||||
security_group_id = create_security_group(ec2_client)
|
||||
vpc_id = create_vpc(ec2_client)
|
||||
subnet_ids = create_subnets(ec2_client, vpc_id)
|
||||
ip0_values = ["10.0.1.201", "10.0.1.202", "10.0.1.203", "10.0.1.204"]
|
||||
ip1_values = ["10.0.0.21", "10.0.0.22", "10.0.0.23", "10.0.0.24"]
|
||||
endpoints = []
|
||||
for idx in range(1, 5):
|
||||
ip_addrs = [
|
||||
{"SubnetId": subnet_ids[0], "Ip": "10.0.1.200"},
|
||||
{"SubnetId": subnet_ids[1], "Ip": "10.0.0.20"},
|
||||
{"SubnetId": subnet_ids[0], "Ip": ip0_values[idx - 1]},
|
||||
{"SubnetId": subnet_ids[1], "Ip": ip1_values[idx - 1]},
|
||||
]
|
||||
response = client.create_resolver_endpoint(
|
||||
CreatorRequestId=f"F{idx}-{random_num}",
|
||||
Name=f"F{idx}-{random_num}",
|
||||
SecurityGroupIds=[security_group_id],
|
||||
Direction="INBOUND" if idx % 2 else "OUTBOUND",
|
||||
IpAddresses=ip_addrs,
|
||||
)
|
||||
endpoints.append(response["ResolverEndpoint"])
|
||||
|
||||
# Try all the valid filter names, including some of the old style names.
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[{"Name": "CreatorRequestId", "Values": [f"F3-{random_num}"]}]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 1
|
||||
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F3-{random_num}"
|
||||
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[
|
||||
{
|
||||
"Name": "CREATOR_REQUEST_ID",
|
||||
"Values": [f"F2-{random_num}", f"F4-{random_num}"],
|
||||
}
|
||||
]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 2
|
||||
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F2-{random_num}"
|
||||
assert response["ResolverEndpoints"][1]["CreatorRequestId"] == f"F4-{random_num}"
|
||||
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[{"Name": "Direction", "Values": ["INBOUND"]}]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 2
|
||||
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F1-{random_num}"
|
||||
assert response["ResolverEndpoints"][1]["CreatorRequestId"] == f"F3-{random_num}"
|
||||
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[{"Name": "HostVPCId", "Values": [vpc_id]}]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 4
|
||||
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[{"Name": "IpAddressCount", "Values": ["4"]}]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 4
|
||||
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[{"Name": "Name", "Values": [f"F1-{random_num}"]}]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 1
|
||||
assert response["ResolverEndpoints"][0]["Name"] == f"F1-{random_num}"
|
||||
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[
|
||||
{"Name": "HOST_VPC_ID", "Values": [vpc_id]},
|
||||
{"Name": "DIRECTION", "Values": ["INBOUND"]},
|
||||
{"Name": "NAME", "Values": [f"F3-{random_num}"]},
|
||||
]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 1
|
||||
assert response["ResolverEndpoints"][0]["Name"] == f"F3-{random_num}"
|
||||
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[{"Name": "SecurityGroupIds", "Values": [security_group_id]}]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 4
|
||||
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[{"Name": "Status", "Values": ["OPERATIONAL"]}]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 4
|
||||
response = client.list_resolver_endpoints(
|
||||
Filters=[{"Name": "Status", "Values": ["CREATING"]}]
|
||||
)
|
||||
assert len(response["ResolverEndpoints"]) == 0
|
||||
|
||||
|
||||
@mock_route53resolver
|
||||
def test_route53resolver_bad_list_resolver_endpoints_filters():
|
||||
"""Test bad list_resolver_endpoint API calls that use filters."""
|
||||
client = boto3.client("route53resolver", region_name=TEST_REGION)
|
||||
|
||||
# botocore barfs on an empty "Values":
|
||||
# TypeError: list_resolver_endpoints() only accepts keyword arguments.
|
||||
# client.list_resolver_endpoints([{"Name": "Direction", "Values": []}])
|
||||
# client.list_resolver_endpoints([{"Values": []}])
|
||||
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.list_resolver_endpoints(Filters=[{"Name": "foo", "Values": ["bar"]}])
|
||||
err = exc.value.response["Error"]
|
||||
assert err["Code"] == "InvalidParameterException"
|
||||
assert "The filter 'foo' is invalid" in err["Message"]
|
||||
|
||||
|
||||
@mock_ec2
|
||||
@mock_route53resolver
|
||||
def test_route53resolver_bad_list_resolver_endpoints():
|
||||
|
Loading…
Reference in New Issue
Block a user