refactor(events): Improve put_rule and event pattern (#4158)

This commit is contained in:
Gonzalo Saad 2021-08-13 02:01:44 -03:00 committed by GitHub
parent 67199d9828
commit 9e61ab2220
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 242 additions and 172 deletions

View File

@ -12,9 +12,13 @@ class IllegalStatusException(JsonRESTError):
class InvalidEventPatternException(JsonRESTError):
code = 400
def __init__(self):
def __init__(self, reason=None):
msg = "Event pattern is not valid. "
if reason:
msg += f"Reason: {reason}"
super(InvalidEventPatternException, self).__init__(
"InvalidEventPatternException", "Event pattern is not valid."
"InvalidEventPatternException", msg
)

View File

@ -30,18 +30,30 @@ from uuid import uuid4
class Rule(CloudFormationModel):
Arn = namedtuple("Arn", ["service", "resource_type", "resource_id"])
def __init__(self, name, region_name, **kwargs):
def __init__(
self,
name,
region_name,
description,
event_pattern,
schedule_exp,
role_arn,
event_bus_name,
state,
managed_by=None,
targets=None,
):
self.name = name
self.region_name = region_name
self.event_pattern = EventPattern(kwargs.get("EventPattern"))
self.schedule_exp = kwargs.get("ScheduleExpression")
self.state = kwargs.get("State") or "ENABLED"
self.description = kwargs.get("Description")
self.role_arn = kwargs.get("RoleArn")
self.managed_by = kwargs.get("ManagedBy") # can only be set by AWS services
self.event_bus_name = kwargs.get("EventBusName")
self.description = description
self.event_pattern = EventPattern.load(event_pattern)
self.scheduled_expression = schedule_exp
self.role_arn = role_arn
self.event_bus_name = event_bus_name
self.state = state or "ENABLED"
self.managed_by = managed_by # can only be set by AWS services
self.created_by = ACCOUNT_ID
self.targets = []
self.targets = targets or []
@property
def arn(self):
@ -190,7 +202,8 @@ class Rule(CloudFormationModel):
)
if queue_attr["ContentBasedDeduplication"] == "false":
warnings.warn(
"To let EventBridge send messages to your SQS FIFO queue, you must enable content-based deduplication."
"To let EventBridge send messages to your SQS FIFO queue, "
"you must enable content-based deduplication."
)
return
@ -223,10 +236,27 @@ class Rule(CloudFormationModel):
):
properties = cloudformation_json["Properties"]
properties.setdefault("EventBusName", "default")
event_backend = events_backends[region_name]
event_name = resource_name
return event_backend.put_rule(name=event_name, **properties)
event_pattern = properties.get("EventPattern")
scheduled_expression = properties.get("ScheduleExpression")
state = properties.get("State")
desc = properties.get("Description")
role_arn = properties.get("RoleArn")
event_bus_name = properties.get("EventBusName")
tags = properties.get("Tags")
backend = events_backends[region_name]
return backend.put_rule(
event_name,
scheduled_expression=scheduled_expression,
event_pattern=event_pattern,
state=state,
description=desc,
role_arn=role_arn,
event_bus_name=event_bus_name,
tags=tags,
)
@classmethod
def update_from_cloudformation_json(
@ -245,6 +275,24 @@ class Rule(CloudFormationModel):
event_name = resource_name
event_backend.delete_rule(name=event_name)
def describe(self):
attributes = {
"Arn": self.arn,
"CreatedBy": self.created_by,
"Description": self.description,
"EventBusName": self.event_bus_name,
"EventPattern": self.event_pattern.dump(),
"ManagedBy": self.managed_by,
"Name": self.name,
"RoleArn": self.role_arn,
"ScheduleExpression": self.scheduled_expression,
"State": self.state,
}
attributes = {
attr: value for attr, value in attributes.items() if value is not None
}
return attributes
class EventBus(CloudFormationModel):
def __init__(self, region_name, name, tags=None):
@ -426,7 +474,7 @@ class Archive(CloudFormationModel):
self.name = name
self.source_arn = source_arn
self.description = description
self.event_pattern = EventPattern(event_pattern)
self.event_pattern = EventPattern.load(event_pattern)
self.retention = retention if retention else 0
self.creation_time = unix_time(datetime.utcnow())
@ -457,7 +505,7 @@ class Archive(CloudFormationModel):
result = {
"ArchiveArn": self.arn,
"Description": self.description,
"EventPattern": str(self.event_pattern),
"EventPattern": self.event_pattern.dump(),
}
result.update(self.describe_short())
@ -467,7 +515,7 @@ class Archive(CloudFormationModel):
if description:
self.description = description
if event_pattern:
self.event_pattern = EventPattern(event_pattern)
self.event_pattern = EventPattern.load(event_pattern)
if retention:
self.retention = retention
@ -644,7 +692,7 @@ class Connection(BaseModel):
- The original response also has
- LastAuthorizedTime (number)
- LastModifiedTime (number)
- At the time of implemeting this, there was no place where to set/get
- At the time of implementing this, there was no place where to set/get
those attributes. That is why they are not in the response.
Returns:
@ -669,7 +717,7 @@ class Connection(BaseModel):
- LastModifiedTime (number)
- SecretArn (string)
- StateReason (string)
- At the time of implemeting this, there was no place where to set/get
- At the time of implementing this, there was no place where to set/get
those attributes. That is why they are not in the response.
Returns:
@ -751,46 +799,18 @@ class Destination(BaseModel):
class EventPattern:
def __init__(self, filter):
self._filter = self._load_event_pattern(filter)
self._filter_raw = filter
if not self._validate_event_pattern(self._filter):
raise InvalidEventPatternException
def __str__(self):
return self._filter_raw or str()
def _load_event_pattern(self, pattern):
try:
return json.loads(pattern) if pattern else None
except ValueError:
raise InvalidEventPatternException
def _validate_event_pattern(self, pattern):
# values in the event pattern have to be either a dict or an array
if pattern is None:
return True
dicts_valid = [
self._validate_event_pattern(value)
for value in pattern.values()
if isinstance(value, dict)
]
non_dicts_valid = [
isinstance(value, list)
for value in pattern.values()
if not isinstance(value, dict)
]
return all(dicts_valid) and all(non_dicts_valid)
def __init__(self, raw_pattern, pattern):
self._raw_pattern = raw_pattern
self._pattern = pattern
def matches_event(self, event):
if not self._filter:
if not self._pattern:
return True
event = json.loads(json.dumps(event))
return self._does_event_match(event, self._filter)
return self._does_event_match(event, self._pattern)
def _does_event_match(self, event, filter):
items_and_filters = [(event.get(k), v) for k, v in filter.items()]
def _does_event_match(self, event, pattern):
items_and_filters = [(event.get(k), v) for k, v in pattern.items()]
nested_filter_matches = [
self._does_event_match(item, nested_filter)
for item, nested_filter in items_and_filters
@ -807,14 +827,15 @@ class EventPattern:
allowed_values = [value for value in filters if isinstance(value, str)]
allowed_values_match = item in allowed_values if allowed_values else True
named_filter_matches = [
self._does_item_match_named_filter(item, filter)
for filter in filters
if isinstance(filter, dict)
self._does_item_match_named_filter(item, pattern)
for pattern in filters
if isinstance(pattern, dict)
]
return allowed_values_match and all(named_filter_matches)
def _does_item_match_named_filter(self, item, filter):
filter_name, filter_value = list(filter.items())[0]
@staticmethod
def _does_item_match_named_filter(item, pattern):
filter_name, filter_value = list(pattern.items())[0]
if filter_name == "exists":
is_leaf_node = not isinstance(item, dict)
leaf_exists = is_leaf_node and item is not None
@ -839,10 +860,49 @@ class EventPattern:
)
return True
@classmethod
def load(cls, raw_pattern):
parser = EventPatternParser(raw_pattern)
pattern = parser.parse()
return cls(raw_pattern, pattern)
def dump(self):
return self._raw_pattern
class EventPatternParser:
def __init__(self, pattern):
self.pattern = pattern
def _validate_event_pattern(self, pattern):
# values in the event pattern have to be either a dict or an array
for attr, value in pattern.items():
if isinstance(value, dict):
self._validate_event_pattern(value)
elif isinstance(value, list):
if len(value) == 0:
raise InvalidEventPatternException(
reason="Empty arrays are not allowed"
)
else:
raise InvalidEventPatternException(
reason=f"'{attr}' must be an object or an array"
)
def parse(self):
try:
parsed_pattern = json.loads(self.pattern) if self.pattern else dict()
self._validate_event_pattern(parsed_pattern)
return parsed_pattern
except JSONDecodeError:
raise InvalidEventPatternException(reason="Invalid JSON")
class EventsBackend(BaseBackend):
ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$")
STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$")
_CRON_REGEX = re.compile(r"^cron\(.*\)")
_RATE_REGEX = re.compile(r"^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)")
def __init__(self, region_name):
self.rules = {}
@ -911,6 +971,61 @@ class EventsBackend(BaseBackend):
return replay
def put_rule(
self,
name,
*,
description=None,
event_bus_name=None,
event_pattern=None,
role_arn=None,
scheduled_expression=None,
state=None,
managed_by=None,
tags=None,
):
event_bus_name = event_bus_name or "default"
if not event_pattern and not scheduled_expression:
raise JsonRESTError(
"ValidationException",
"Parameter(s) EventPattern or ScheduleExpression must be specified.",
)
if scheduled_expression:
if event_bus_name != "default":
raise ValidationException(
"ScheduleExpression is supported only on the default event bus."
)
if not (
self._CRON_REGEX.match(scheduled_expression)
or self._RATE_REGEX.match(scheduled_expression)
):
raise ValidationException("Parameter ScheduleExpression is not valid.")
existing_rule = self.rules.get(name)
targets = existing_rule.targets if existing_rule else list()
rule = Rule(
name,
self.region_name,
description,
event_pattern,
scheduled_expression,
role_arn,
event_bus_name,
state,
managed_by,
targets=targets,
)
self.rules[name] = rule
self.rules_order.append(name)
if tags:
self.tagger.tag_resource(rule.arn, tags)
return rule
def delete_rule(self, name):
self.rules_order.pop(self.rules_order.index(name))
arn = self.rules.get(name).arn
@ -919,7 +1034,10 @@ class EventsBackend(BaseBackend):
return self.rules.pop(name) is not None
def describe_rule(self, name):
return self.rules.get(name)
rule = self.rules.get(name)
if not rule:
raise ResourceNotFoundException("Rule {} does not exist.".format(name))
return rule
def disable_rule(self, name):
if name in self.rules:
@ -1001,28 +1119,6 @@ class EventsBackend(BaseBackend):
return return_obj
def update_rule(self, rule, **kwargs):
rule.event_pattern = kwargs.get("EventPattern") or rule.event_pattern
rule.schedule_exp = kwargs.get("ScheduleExpression") or rule.schedule_exp
rule.state = kwargs.get("State") or rule.state
rule.description = kwargs.get("Description") or rule.description
rule.role_arn = kwargs.get("RoleArn") or rule.role_arn
rule.event_bus_name = kwargs.get("EventBusName") or rule.event_bus_name
def put_rule(self, name, **kwargs):
if kwargs.get("ScheduleExpression") and kwargs.get("EventBusName") != "default":
raise ValidationException(
"ScheduleExpression is supported only on the default event bus."
)
if name in self.rules:
self.update_rule(self.rules[name], **kwargs)
new_rule = self.rules[name]
else:
new_rule = Rule(name, self.region_name, **kwargs)
self.rules[new_rule.name] = new_rule
self.rules_order.append(new_rule.name)
return new_rule
def put_targets(self, name, event_bus_name, targets):
# super simple ARN check
invalid_arn = next(
@ -1342,13 +1438,12 @@ class EventsBackend(BaseBackend):
rule_event_pattern = json.loads(event_pattern or "{}")
rule_event_pattern["replay-name"] = [{"exists": False}]
rule_name = "Events-Archive-{}".format(name)
rule = self.put_rule(
"Events-Archive-{}".format(name),
**{
"EventPattern": json.dumps(rule_event_pattern),
"EventBusName": event_bus.name,
"ManagedBy": "prod.vhs.events.aws.internal",
},
rule_name,
event_pattern=json.dumps(rule_event_pattern),
event_bus_name=event_bus.name,
managed_by="prod.vhs.events.aws.internal",
)
self.put_targets(
rule.name,

View File

@ -1,5 +1,4 @@
import json
import re
from moto.core.responses import BaseResponse
from moto.events import events_backends
@ -16,20 +15,6 @@ class EventsHandler(BaseResponse):
"""
return events_backends[self.region]
def _generate_rule_dict(self, rule):
return {
"Name": rule.name,
"Arn": rule.arn,
"EventPattern": str(rule.event_pattern),
"State": rule.state,
"Description": rule.description,
"ScheduleExpression": rule.schedule_exp,
"RoleArn": rule.role_arn,
"ManagedBy": rule.managed_by,
"EventBusName": rule.event_bus_name,
"CreatedBy": rule.created_by,
}
@property
def request_params(self):
if not hasattr(self, "_json_body"):
@ -61,6 +46,29 @@ class EventsHandler(BaseResponse):
headers["status"] = status
return json.dumps({"__type": type_, "message": message}), headers
def put_rule(self):
name = self._get_param("Name")
event_pattern = self._get_param("EventPattern")
scheduled_expression = self._get_param("ScheduleExpression")
state = self._get_param("State")
desc = self._get_param("Description")
role_arn = self._get_param("RoleArn")
event_bus_name = self._get_param("EventBusName")
tags = self._get_param("Tags")
rule = self.events_backend.put_rule(
name,
scheduled_expression=scheduled_expression,
event_pattern=event_pattern,
state=state,
description=desc,
role_arn=role_arn,
event_bus_name=event_bus_name,
tags=tags,
)
result = {"RuleArn": rule.arn}
return self._create_response(result)
def delete_rule(self):
name = self._get_param("Name")
@ -78,13 +86,8 @@ class EventsHandler(BaseResponse):
rule = self.events_backend.describe_rule(name)
if not rule:
return self.error(
"ResourceNotFoundException", "Rule " + name + " does not exist."
)
rule_dict = self._generate_rule_dict(rule)
return json.dumps(rule_dict), self.response_headers
result = rule.describe()
return self._create_response(result)
def disable_rule(self):
name = self._get_param("Name")
@ -138,7 +141,7 @@ class EventsHandler(BaseResponse):
rules_obj = {"Rules": []}
for rule in rules["Rules"]:
rules_obj["Rules"].append(self._generate_rule_dict(rule))
rules_obj["Rules"].append(rule.describe())
if rules.get("NextToken"):
rules_obj["NextToken"] = rules["NextToken"]
@ -177,48 +180,6 @@ class EventsHandler(BaseResponse):
return json.dumps(response)
def put_rule(self):
name = self._get_param("Name")
event_pattern = self._get_param("EventPattern")
sched_exp = self._get_param("ScheduleExpression")
state = self._get_param("State")
desc = self._get_param("Description")
role_arn = self._get_param("RoleArn")
event_bus_name = self._get_param("EventBusName", "default")
if event_pattern:
try:
json.loads(event_pattern)
except ValueError:
# Not quite as informative as the real error, but it'll work
# for now.
return self.error(
"InvalidEventPatternException", "Event pattern is not valid."
)
if sched_exp:
if not (
re.match(r"^cron\(.*\)", sched_exp)
or re.match(
r"^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)", sched_exp
)
):
return self.error(
"ValidationException", "Parameter ScheduleExpression is not valid."
)
rule = self.events_backend.put_rule(
name,
ScheduleExpression=sched_exp,
EventPattern=event_pattern,
State=state,
Description=desc,
RoleArn=role_arn,
EventBusName=event_bus_name,
)
return json.dumps({"RuleArn": rule.arn}), self.response_headers
def put_targets(self):
rule_name = self._get_param("Rule")
event_bus_name = self._get_param("EventBusName", "default")

View File

@ -10,6 +10,7 @@ TestAccAWSCloudWatchEventBus
TestAccAWSCloudwatchEventBusPolicy
TestAccAWSCloudWatchEventConnection
TestAccAWSCloudWatchEventPermission
TestAccAWSCloudWatchEventRule
TestAccAWSCloudwatchLogGroupDataSource
TestAccAWSDataSourceCloudwatch
TestAccAWSDataSourceElasticBeanstalkHostedZone

View File

@ -6,41 +6,45 @@ from moto.events.models import EventPattern
def test_event_pattern_with_allowed_values_event_filter():
pattern = EventPattern(json.dumps({"source": ["foo", "bar"]}))
pattern = EventPattern.load(json.dumps({"source": ["foo", "bar"]}))
assert pattern.matches_event({"source": "foo"})
assert pattern.matches_event({"source": "bar"})
assert not pattern.matches_event({"source": "baz"})
def test_event_pattern_with_nested_event_filter():
pattern = EventPattern(json.dumps({"detail": {"foo": ["bar"]}}))
pattern = EventPattern.load(json.dumps({"detail": {"foo": ["bar"]}}))
assert pattern.matches_event({"detail": {"foo": "bar"}})
assert not pattern.matches_event({"detail": {"foo": "baz"}})
def test_event_pattern_with_exists_event_filter():
foo_exists = EventPattern(json.dumps({"detail": {"foo": [{"exists": True}]}}))
foo_exists = EventPattern.load(json.dumps({"detail": {"foo": [{"exists": True}]}}))
assert foo_exists.matches_event({"detail": {"foo": "bar"}})
assert not foo_exists.matches_event({"detail": {}})
# exists filters only match leaf nodes of an event
assert not foo_exists.matches_event({"detail": {"foo": {"bar": "baz"}}})
foo_not_exists = EventPattern(json.dumps({"detail": {"foo": [{"exists": False}]}}))
foo_not_exists = EventPattern.load(
json.dumps({"detail": {"foo": [{"exists": False}]}})
)
assert not foo_not_exists.matches_event({"detail": {"foo": "bar"}})
assert foo_not_exists.matches_event({"detail": {}})
assert foo_not_exists.matches_event({"detail": {"foo": {"bar": "baz"}}})
bar_exists = EventPattern(json.dumps({"detail": {"bar": [{"exists": True}]}}))
bar_exists = EventPattern.load(json.dumps({"detail": {"bar": [{"exists": True}]}}))
assert not bar_exists.matches_event({"detail": {"foo": "bar"}})
assert not bar_exists.matches_event({"detail": {}})
bar_not_exists = EventPattern(json.dumps({"detail": {"bar": [{"exists": False}]}}))
bar_not_exists = EventPattern.load(
json.dumps({"detail": {"bar": [{"exists": False}]}})
)
assert bar_not_exists.matches_event({"detail": {"foo": "bar"}})
assert bar_not_exists.matches_event({"detail": {}})
def test_event_pattern_with_prefix_event_filter():
pattern = EventPattern(json.dumps({"detail": {"foo": [{"prefix": "bar"}]}}))
pattern = EventPattern.load(json.dumps({"detail": {"foo": [{"prefix": "bar"}]}}))
assert pattern.matches_event({"detail": {"foo": "bar"}})
assert pattern.matches_event({"detail": {"foo": "bar!"}})
assert not pattern.matches_event({"detail": {"foo": "ba"}})
@ -59,7 +63,7 @@ def test_event_pattern_with_prefix_event_filter():
def test_event_pattern_with_single_numeric_event_filter(
operator, compare_to, should_match, should_not_match
):
pattern = EventPattern(
pattern = EventPattern.load(
json.dumps({"detail": {"foo": [{"numeric": [operator, compare_to]}]}})
)
for number in should_match:
@ -71,7 +75,7 @@ def test_event_pattern_with_single_numeric_event_filter(
def test_event_pattern_with_multi_numeric_event_filter():
events = [{"detail": {"foo": number}} for number in range(5)]
one_or_two = EventPattern(
one_or_two = EventPattern.load(
json.dumps({"detail": {"foo": [{"numeric": [">=", 1, "<", 3]}]}})
)
assert not one_or_two.matches_event(events[0])
@ -80,7 +84,7 @@ def test_event_pattern_with_multi_numeric_event_filter():
assert not one_or_two.matches_event(events[3])
assert not one_or_two.matches_event(events[4])
two_or_three = EventPattern(
two_or_three = EventPattern.load(
json.dumps({"detail": {"foo": [{"numeric": [">", 1, "<=", 3]}]}})
)
assert not two_or_three.matches_event(events[0])
@ -92,8 +96,8 @@ def test_event_pattern_with_multi_numeric_event_filter():
@pytest.mark.parametrize(
"pattern, expected_str",
[('{"source": ["foo", "bar"]}', '{"source": ["foo", "bar"]}'), (None, ""),],
[('{"source": ["foo", "bar"]}', '{"source": ["foo", "bar"]}'), (None, None),],
)
def test_event_pattern_str(pattern, expected_str):
event_pattern = EventPattern(pattern)
assert str(event_pattern) == expected_str
def test_event_pattern_dump(pattern, expected_str):
event_pattern = EventPattern.load(pattern)
assert event_pattern.dump() == expected_str

View File

@ -1050,7 +1050,9 @@ def test_create_archive_error_invalid_event_pattern():
ex.operation_name.should.equal("CreateArchive")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidEventPatternException")
ex.response["Error"]["Message"].should.equal("Event pattern is not valid.")
ex.response["Error"]["Message"].should.equal(
"Event pattern is not valid. Reason: Invalid JSON"
)
@mock_events
@ -1080,7 +1082,9 @@ def test_create_archive_error_invalid_event_pattern_not_an_array():
ex.operation_name.should.equal("CreateArchive")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidEventPatternException")
ex.response["Error"]["Message"].should.equal("Event pattern is not valid.")
ex.response["Error"]["Message"].should.equal(
"Event pattern is not valid. Reason: 'key_6' must be an object or an array"
)
@mock_events
@ -1378,7 +1382,9 @@ def test_update_archive_error_invalid_event_pattern():
ex.operation_name.should.equal("UpdateArchive")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidEventPatternException")
ex.response["Error"]["Message"].should.equal("Event pattern is not valid.")
ex.response["Error"]["Message"].should.equal(
"Event pattern is not valid. Reason: Invalid JSON"
)
@mock_events
@ -2399,7 +2405,6 @@ def test_delete_connection_success():
response = client.delete_connection(Name=conn_name)
# Then
expected_arn = f"arn:aws:events:eu-central-1:{ACCOUNT_ID}:connection/{conn_name}/"
assert response["ConnectionArn"] == created_connection["ConnectionArn"]
assert response["ConnectionState"] == created_connection["ConnectionState"]
assert response["CreationTime"] == created_connection["CreationTime"]