Add support for filters for list_resolver_endpoints (#4598)
This commit is contained in:
parent
32b2a90ee6
commit
be197caba6
@ -2,6 +2,7 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_address, ip_network
|
||||||
|
import re
|
||||||
|
|
||||||
from boto3 import Session
|
from boto3 import Session
|
||||||
|
|
||||||
@ -25,6 +26,8 @@ from moto.route53resolver.validations import validate_args
|
|||||||
from moto.utilities.paginator import paginate
|
from moto.utilities.paginator import paginate
|
||||||
from moto.utilities.tagging_service import TaggingService
|
from moto.utilities.tagging_service import TaggingService
|
||||||
|
|
||||||
|
CAMEL_TO_SNAKE_PATTERN = re.compile(r"(?<!^)(?=[A-Z])")
|
||||||
|
|
||||||
|
|
||||||
class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attributes
|
class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attributes
|
||||||
"""Representation of a fake Route53 Resolver Endpoint."""
|
"""Representation of a fake Route53 Resolver Endpoint."""
|
||||||
@ -32,6 +35,18 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
|
|||||||
MAX_TAGS_PER_RESOLVER_ENDPOINT = 200
|
MAX_TAGS_PER_RESOLVER_ENDPOINT = 200
|
||||||
MAX_ENDPOINTS_PER_REGION = 4
|
MAX_ENDPOINTS_PER_REGION = 4
|
||||||
|
|
||||||
|
# There are two styles of filter names and either will be transformed
|
||||||
|
# into lowercase snake.
|
||||||
|
FILTER_NAMES = [
|
||||||
|
"creator_request_id",
|
||||||
|
"direction",
|
||||||
|
"host_vpc_id",
|
||||||
|
"ip_address_count",
|
||||||
|
"name",
|
||||||
|
"security_group_ids",
|
||||||
|
"status",
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
region,
|
region,
|
||||||
@ -328,14 +343,71 @@ class Route53ResolverBackend(BaseBackend):
|
|||||||
endpoint = self.resolver_endpoints[resolver_endpoint_id]
|
endpoint = self.resolver_endpoints[resolver_endpoint_id]
|
||||||
return endpoint.ip_descriptions()
|
return endpoint.ip_descriptions()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _add_field_name_to_filter(filters):
|
||||||
|
"""Convert both styles of filter names to lowercase snake format.
|
||||||
|
|
||||||
|
"IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count".
|
||||||
|
However, "HostVPCId" doesn't fit the pattern, so that's treated
|
||||||
|
special.
|
||||||
|
"""
|
||||||
|
for rr_filter in filters:
|
||||||
|
filter_name = rr_filter["Name"]
|
||||||
|
if "_" not in filter_name:
|
||||||
|
if filter_name == "HostVPCId":
|
||||||
|
filter_name = "host_vpc_id"
|
||||||
|
elif filter_name == "HostVpcId":
|
||||||
|
filter_name = "WRONG"
|
||||||
|
elif not filter_name.isupper():
|
||||||
|
filter_name = CAMEL_TO_SNAKE_PATTERN.sub("_", filter_name)
|
||||||
|
rr_filter["Field"] = filter_name.lower()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_filters(filters, allowed_filter_names):
|
||||||
|
"""Raise exception if filter names are not as expected."""
|
||||||
|
for rr_filter in filters:
|
||||||
|
if rr_filter["Field"] not in allowed_filter_names:
|
||||||
|
raise InvalidParameterException(
|
||||||
|
f"The filter '{rr_filter['Name']}' is invalid"
|
||||||
|
)
|
||||||
|
if "Values" not in rr_filter:
|
||||||
|
raise InvalidParameterException(
|
||||||
|
f"No values specified for filter {rr_filter['Name']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _matches_all_filters(entity, filters):
|
||||||
|
"""Return True if this entity has fields matching all the filters."""
|
||||||
|
for rr_filter in filters:
|
||||||
|
field_value = getattr(entity, rr_filter["Field"])
|
||||||
|
|
||||||
|
if isinstance(field_value, list):
|
||||||
|
if not set(field_value).intersection(rr_filter["Values"]):
|
||||||
|
return False
|
||||||
|
elif isinstance(field_value, int):
|
||||||
|
if str(field_value) not in rr_filter["Values"]:
|
||||||
|
return False
|
||||||
|
elif field_value not in rr_filter["Values"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
@paginate(pagination_model=PAGINATION_MODEL)
|
@paginate(pagination_model=PAGINATION_MODEL)
|
||||||
def list_resolver_endpoints(
|
def list_resolver_endpoints(
|
||||||
self, filters=None, next_token=None, max_results=None,
|
self, filters, next_token=None, max_results=None,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
"""List all resolver endpoints, using filters if specified."""
|
"""List all resolver endpoints, using filters if specified."""
|
||||||
# TODO - check subsequent filters
|
if not filters:
|
||||||
# TODO - validate name, values for filters
|
filters = []
|
||||||
return sorted(self.resolver_endpoints.values(), key=lambda x: x.name)
|
|
||||||
|
self._add_field_name_to_filter(filters)
|
||||||
|
self._validate_filters(filters, ResolverEndpoint.FILTER_NAMES)
|
||||||
|
|
||||||
|
endpoints = []
|
||||||
|
for endpoint in sorted(self.resolver_endpoints.values(), key=lambda x: x.name):
|
||||||
|
if self._matches_all_filters(endpoint, filters):
|
||||||
|
endpoints.append(endpoint)
|
||||||
|
return endpoints
|
||||||
|
|
||||||
@paginate(pagination_model=PAGINATION_MODEL)
|
@paginate(pagination_model=PAGINATION_MODEL)
|
||||||
def list_tags_for_resource(
|
def list_tags_for_resource(
|
||||||
|
@ -88,7 +88,7 @@ class Route53ResolverResponse(BaseResponse):
|
|||||||
endpoints,
|
endpoints,
|
||||||
next_token,
|
next_token,
|
||||||
) = self.route53resolver_backend.list_resolver_endpoints(
|
) = self.route53resolver_backend.list_resolver_endpoints(
|
||||||
filters=filters, next_token=next_token, max_results=max_results
|
filters, next_token=next_token, max_results=max_results
|
||||||
)
|
)
|
||||||
except InvalidToken as exc:
|
except InvalidToken as exc:
|
||||||
raise InvalidNextTokenException() from exc
|
raise InvalidNextTokenException() from exc
|
||||||
|
@ -670,6 +670,121 @@ def test_route53resolver_list_resolver_endpoints():
|
|||||||
assert endpoint["Name"].startswith(f"A{idx + 1}")
|
assert endpoint["Name"].startswith(f"A{idx + 1}")
|
||||||
|
|
||||||
|
|
||||||
|
@mock_ec2
|
||||||
|
@mock_route53resolver
|
||||||
|
def test_route53resolver_list_resolver_endpoints_filters():
|
||||||
|
"""Test good list_resolver_endpoint API calls that use filters."""
|
||||||
|
client = boto3.client("route53resolver", region_name=TEST_REGION)
|
||||||
|
ec2_client = boto3.client("ec2", region_name=TEST_REGION)
|
||||||
|
random_num = get_random_hex(10)
|
||||||
|
|
||||||
|
# Create some endpoints for testing purposes
|
||||||
|
security_group_id = create_security_group(ec2_client)
|
||||||
|
vpc_id = create_vpc(ec2_client)
|
||||||
|
subnet_ids = create_subnets(ec2_client, vpc_id)
|
||||||
|
ip0_values = ["10.0.1.201", "10.0.1.202", "10.0.1.203", "10.0.1.204"]
|
||||||
|
ip1_values = ["10.0.0.21", "10.0.0.22", "10.0.0.23", "10.0.0.24"]
|
||||||
|
endpoints = []
|
||||||
|
for idx in range(1, 5):
|
||||||
|
ip_addrs = [
|
||||||
|
{"SubnetId": subnet_ids[0], "Ip": "10.0.1.200"},
|
||||||
|
{"SubnetId": subnet_ids[1], "Ip": "10.0.0.20"},
|
||||||
|
{"SubnetId": subnet_ids[0], "Ip": ip0_values[idx - 1]},
|
||||||
|
{"SubnetId": subnet_ids[1], "Ip": ip1_values[idx - 1]},
|
||||||
|
]
|
||||||
|
response = client.create_resolver_endpoint(
|
||||||
|
CreatorRequestId=f"F{idx}-{random_num}",
|
||||||
|
Name=f"F{idx}-{random_num}",
|
||||||
|
SecurityGroupIds=[security_group_id],
|
||||||
|
Direction="INBOUND" if idx % 2 else "OUTBOUND",
|
||||||
|
IpAddresses=ip_addrs,
|
||||||
|
)
|
||||||
|
endpoints.append(response["ResolverEndpoint"])
|
||||||
|
|
||||||
|
# Try all the valid filter names, including some of the old style names.
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[{"Name": "CreatorRequestId", "Values": [f"F3-{random_num}"]}]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 1
|
||||||
|
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F3-{random_num}"
|
||||||
|
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[
|
||||||
|
{
|
||||||
|
"Name": "CREATOR_REQUEST_ID",
|
||||||
|
"Values": [f"F2-{random_num}", f"F4-{random_num}"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 2
|
||||||
|
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F2-{random_num}"
|
||||||
|
assert response["ResolverEndpoints"][1]["CreatorRequestId"] == f"F4-{random_num}"
|
||||||
|
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[{"Name": "Direction", "Values": ["INBOUND"]}]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 2
|
||||||
|
assert response["ResolverEndpoints"][0]["CreatorRequestId"] == f"F1-{random_num}"
|
||||||
|
assert response["ResolverEndpoints"][1]["CreatorRequestId"] == f"F3-{random_num}"
|
||||||
|
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[{"Name": "HostVPCId", "Values": [vpc_id]}]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 4
|
||||||
|
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[{"Name": "IpAddressCount", "Values": ["4"]}]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 4
|
||||||
|
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[{"Name": "Name", "Values": [f"F1-{random_num}"]}]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 1
|
||||||
|
assert response["ResolverEndpoints"][0]["Name"] == f"F1-{random_num}"
|
||||||
|
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[
|
||||||
|
{"Name": "HOST_VPC_ID", "Values": [vpc_id]},
|
||||||
|
{"Name": "DIRECTION", "Values": ["INBOUND"]},
|
||||||
|
{"Name": "NAME", "Values": [f"F3-{random_num}"]},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 1
|
||||||
|
assert response["ResolverEndpoints"][0]["Name"] == f"F3-{random_num}"
|
||||||
|
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[{"Name": "SecurityGroupIds", "Values": [security_group_id]}]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 4
|
||||||
|
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[{"Name": "Status", "Values": ["OPERATIONAL"]}]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 4
|
||||||
|
response = client.list_resolver_endpoints(
|
||||||
|
Filters=[{"Name": "Status", "Values": ["CREATING"]}]
|
||||||
|
)
|
||||||
|
assert len(response["ResolverEndpoints"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@mock_route53resolver
|
||||||
|
def test_route53resolver_bad_list_resolver_endpoints_filters():
|
||||||
|
"""Test bad list_resolver_endpoint API calls that use filters."""
|
||||||
|
client = boto3.client("route53resolver", region_name=TEST_REGION)
|
||||||
|
|
||||||
|
# botocore barfs on an empty "Values":
|
||||||
|
# TypeError: list_resolver_endpoints() only accepts keyword arguments.
|
||||||
|
# client.list_resolver_endpoints([{"Name": "Direction", "Values": []}])
|
||||||
|
# client.list_resolver_endpoints([{"Values": []}])
|
||||||
|
|
||||||
|
with pytest.raises(ClientError) as exc:
|
||||||
|
client.list_resolver_endpoints(Filters=[{"Name": "foo", "Values": ["bar"]}])
|
||||||
|
err = exc.value.response["Error"]
|
||||||
|
assert err["Code"] == "InvalidParameterException"
|
||||||
|
assert "The filter 'foo' is invalid" in err["Message"]
|
||||||
|
|
||||||
|
|
||||||
@mock_ec2
|
@mock_ec2
|
||||||
@mock_route53resolver
|
@mock_route53resolver
|
||||||
def test_route53resolver_bad_list_resolver_endpoints():
|
def test_route53resolver_bad_list_resolver_endpoints():
|
||||||
|
Loading…
Reference in New Issue
Block a user