From 9e61ab22207454f69a84a7a9afaeef9eebafec43 Mon Sep 17 00:00:00 2001 From: Gonzalo Saad Date: Fri, 13 Aug 2021 02:01:44 -0300 Subject: [PATCH] refactor(events): Improve `put_rule` and event pattern (#4158) --- moto/events/exceptions.py | 8 +- moto/events/models.py | 269 ++++++++++++++++-------- moto/events/responses.py | 91 +++----- tests/terraform-tests.success.txt | 1 + tests/test_events/test_event_pattern.py | 32 +-- tests/test_events/test_events.py | 13 +- 6 files changed, 242 insertions(+), 172 deletions(-) diff --git a/moto/events/exceptions.py b/moto/events/exceptions.py index b077cc983..0952edd52 100644 --- a/moto/events/exceptions.py +++ b/moto/events/exceptions.py @@ -12,9 +12,13 @@ class IllegalStatusException(JsonRESTError): class InvalidEventPatternException(JsonRESTError): code = 400 - def __init__(self): + def __init__(self, reason=None): + msg = "Event pattern is not valid. " + if reason: + msg += f"Reason: {reason}" + super(InvalidEventPatternException, self).__init__( - "InvalidEventPatternException", "Event pattern is not valid." + "InvalidEventPatternException", msg ) diff --git a/moto/events/models.py b/moto/events/models.py index 883edbc90..858121d9f 100644 --- a/moto/events/models.py +++ b/moto/events/models.py @@ -30,18 +30,30 @@ from uuid import uuid4 class Rule(CloudFormationModel): Arn = namedtuple("Arn", ["service", "resource_type", "resource_id"]) - def __init__(self, name, region_name, **kwargs): + def __init__( + self, + name, + region_name, + description, + event_pattern, + schedule_exp, + role_arn, + event_bus_name, + state, + managed_by=None, + targets=None, + ): self.name = name self.region_name = region_name - self.event_pattern = EventPattern(kwargs.get("EventPattern")) - self.schedule_exp = kwargs.get("ScheduleExpression") - self.state = kwargs.get("State") or "ENABLED" - self.description = kwargs.get("Description") - self.role_arn = kwargs.get("RoleArn") - self.managed_by = kwargs.get("ManagedBy") # can only be set by AWS services - self.event_bus_name = kwargs.get("EventBusName") + self.description = description + self.event_pattern = EventPattern.load(event_pattern) + self.scheduled_expression = schedule_exp + self.role_arn = role_arn + self.event_bus_name = event_bus_name + self.state = state or "ENABLED" + self.managed_by = managed_by # can only be set by AWS services self.created_by = ACCOUNT_ID - self.targets = [] + self.targets = targets or [] @property def arn(self): @@ -190,7 +202,8 @@ class Rule(CloudFormationModel): ) if queue_attr["ContentBasedDeduplication"] == "false": warnings.warn( - "To let EventBridge send messages to your SQS FIFO queue, you must enable content-based deduplication." + "To let EventBridge send messages to your SQS FIFO queue, " + "you must enable content-based deduplication." ) return @@ -223,10 +236,27 @@ class Rule(CloudFormationModel): ): properties = cloudformation_json["Properties"] properties.setdefault("EventBusName", "default") - - event_backend = events_backends[region_name] event_name = resource_name - return event_backend.put_rule(name=event_name, **properties) + + event_pattern = properties.get("EventPattern") + scheduled_expression = properties.get("ScheduleExpression") + state = properties.get("State") + desc = properties.get("Description") + role_arn = properties.get("RoleArn") + event_bus_name = properties.get("EventBusName") + tags = properties.get("Tags") + + backend = events_backends[region_name] + return backend.put_rule( + event_name, + scheduled_expression=scheduled_expression, + event_pattern=event_pattern, + state=state, + description=desc, + role_arn=role_arn, + event_bus_name=event_bus_name, + tags=tags, + ) @classmethod def update_from_cloudformation_json( @@ -245,6 +275,24 @@ class Rule(CloudFormationModel): event_name = resource_name event_backend.delete_rule(name=event_name) + def describe(self): + attributes = { + "Arn": self.arn, + "CreatedBy": self.created_by, + "Description": self.description, + "EventBusName": self.event_bus_name, + "EventPattern": self.event_pattern.dump(), + "ManagedBy": self.managed_by, + "Name": self.name, + "RoleArn": self.role_arn, + "ScheduleExpression": self.scheduled_expression, + "State": self.state, + } + attributes = { + attr: value for attr, value in attributes.items() if value is not None + } + return attributes + class EventBus(CloudFormationModel): def __init__(self, region_name, name, tags=None): @@ -426,7 +474,7 @@ class Archive(CloudFormationModel): self.name = name self.source_arn = source_arn self.description = description - self.event_pattern = EventPattern(event_pattern) + self.event_pattern = EventPattern.load(event_pattern) self.retention = retention if retention else 0 self.creation_time = unix_time(datetime.utcnow()) @@ -457,7 +505,7 @@ class Archive(CloudFormationModel): result = { "ArchiveArn": self.arn, "Description": self.description, - "EventPattern": str(self.event_pattern), + "EventPattern": self.event_pattern.dump(), } result.update(self.describe_short()) @@ -467,7 +515,7 @@ class Archive(CloudFormationModel): if description: self.description = description if event_pattern: - self.event_pattern = EventPattern(event_pattern) + self.event_pattern = EventPattern.load(event_pattern) if retention: self.retention = retention @@ -644,7 +692,7 @@ class Connection(BaseModel): - The original response also has - LastAuthorizedTime (number) - LastModifiedTime (number) - - At the time of implemeting this, there was no place where to set/get + - At the time of implementing this, there was no place where to set/get those attributes. That is why they are not in the response. Returns: @@ -669,7 +717,7 @@ class Connection(BaseModel): - LastModifiedTime (number) - SecretArn (string) - StateReason (string) - - At the time of implemeting this, there was no place where to set/get + - At the time of implementing this, there was no place where to set/get those attributes. That is why they are not in the response. Returns: @@ -751,46 +799,18 @@ class Destination(BaseModel): class EventPattern: - def __init__(self, filter): - self._filter = self._load_event_pattern(filter) - self._filter_raw = filter - if not self._validate_event_pattern(self._filter): - raise InvalidEventPatternException - - def __str__(self): - return self._filter_raw or str() - - def _load_event_pattern(self, pattern): - try: - return json.loads(pattern) if pattern else None - except ValueError: - raise InvalidEventPatternException - - def _validate_event_pattern(self, pattern): - # values in the event pattern have to be either a dict or an array - if pattern is None: - return True - - dicts_valid = [ - self._validate_event_pattern(value) - for value in pattern.values() - if isinstance(value, dict) - ] - non_dicts_valid = [ - isinstance(value, list) - for value in pattern.values() - if not isinstance(value, dict) - ] - return all(dicts_valid) and all(non_dicts_valid) + def __init__(self, raw_pattern, pattern): + self._raw_pattern = raw_pattern + self._pattern = pattern def matches_event(self, event): - if not self._filter: + if not self._pattern: return True event = json.loads(json.dumps(event)) - return self._does_event_match(event, self._filter) + return self._does_event_match(event, self._pattern) - def _does_event_match(self, event, filter): - items_and_filters = [(event.get(k), v) for k, v in filter.items()] + def _does_event_match(self, event, pattern): + items_and_filters = [(event.get(k), v) for k, v in pattern.items()] nested_filter_matches = [ self._does_event_match(item, nested_filter) for item, nested_filter in items_and_filters @@ -807,14 +827,15 @@ class EventPattern: allowed_values = [value for value in filters if isinstance(value, str)] allowed_values_match = item in allowed_values if allowed_values else True named_filter_matches = [ - self._does_item_match_named_filter(item, filter) - for filter in filters - if isinstance(filter, dict) + self._does_item_match_named_filter(item, pattern) + for pattern in filters + if isinstance(pattern, dict) ] return allowed_values_match and all(named_filter_matches) - def _does_item_match_named_filter(self, item, filter): - filter_name, filter_value = list(filter.items())[0] + @staticmethod + def _does_item_match_named_filter(item, pattern): + filter_name, filter_value = list(pattern.items())[0] if filter_name == "exists": is_leaf_node = not isinstance(item, dict) leaf_exists = is_leaf_node and item is not None @@ -839,10 +860,49 @@ class EventPattern: ) return True + @classmethod + def load(cls, raw_pattern): + parser = EventPatternParser(raw_pattern) + pattern = parser.parse() + return cls(raw_pattern, pattern) + + def dump(self): + return self._raw_pattern + + +class EventPatternParser: + def __init__(self, pattern): + self.pattern = pattern + + def _validate_event_pattern(self, pattern): + # values in the event pattern have to be either a dict or an array + for attr, value in pattern.items(): + if isinstance(value, dict): + self._validate_event_pattern(value) + elif isinstance(value, list): + if len(value) == 0: + raise InvalidEventPatternException( + reason="Empty arrays are not allowed" + ) + else: + raise InvalidEventPatternException( + reason=f"'{attr}' must be an object or an array" + ) + + def parse(self): + try: + parsed_pattern = json.loads(self.pattern) if self.pattern else dict() + self._validate_event_pattern(parsed_pattern) + return parsed_pattern + except JSONDecodeError: + raise InvalidEventPatternException(reason="Invalid JSON") + class EventsBackend(BaseBackend): ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$") STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$") + _CRON_REGEX = re.compile(r"^cron\(.*\)") + _RATE_REGEX = re.compile(r"^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)") def __init__(self, region_name): self.rules = {} @@ -911,6 +971,61 @@ class EventsBackend(BaseBackend): return replay + def put_rule( + self, + name, + *, + description=None, + event_bus_name=None, + event_pattern=None, + role_arn=None, + scheduled_expression=None, + state=None, + managed_by=None, + tags=None, + ): + event_bus_name = event_bus_name or "default" + + if not event_pattern and not scheduled_expression: + raise JsonRESTError( + "ValidationException", + "Parameter(s) EventPattern or ScheduleExpression must be specified.", + ) + + if scheduled_expression: + if event_bus_name != "default": + raise ValidationException( + "ScheduleExpression is supported only on the default event bus." + ) + + if not ( + self._CRON_REGEX.match(scheduled_expression) + or self._RATE_REGEX.match(scheduled_expression) + ): + raise ValidationException("Parameter ScheduleExpression is not valid.") + + existing_rule = self.rules.get(name) + targets = existing_rule.targets if existing_rule else list() + rule = Rule( + name, + self.region_name, + description, + event_pattern, + scheduled_expression, + role_arn, + event_bus_name, + state, + managed_by, + targets=targets, + ) + self.rules[name] = rule + self.rules_order.append(name) + + if tags: + self.tagger.tag_resource(rule.arn, tags) + + return rule + def delete_rule(self, name): self.rules_order.pop(self.rules_order.index(name)) arn = self.rules.get(name).arn @@ -919,7 +1034,10 @@ class EventsBackend(BaseBackend): return self.rules.pop(name) is not None def describe_rule(self, name): - return self.rules.get(name) + rule = self.rules.get(name) + if not rule: + raise ResourceNotFoundException("Rule {} does not exist.".format(name)) + return rule def disable_rule(self, name): if name in self.rules: @@ -1001,28 +1119,6 @@ class EventsBackend(BaseBackend): return return_obj - def update_rule(self, rule, **kwargs): - rule.event_pattern = kwargs.get("EventPattern") or rule.event_pattern - rule.schedule_exp = kwargs.get("ScheduleExpression") or rule.schedule_exp - rule.state = kwargs.get("State") or rule.state - rule.description = kwargs.get("Description") or rule.description - rule.role_arn = kwargs.get("RoleArn") or rule.role_arn - rule.event_bus_name = kwargs.get("EventBusName") or rule.event_bus_name - - def put_rule(self, name, **kwargs): - if kwargs.get("ScheduleExpression") and kwargs.get("EventBusName") != "default": - raise ValidationException( - "ScheduleExpression is supported only on the default event bus." - ) - if name in self.rules: - self.update_rule(self.rules[name], **kwargs) - new_rule = self.rules[name] - else: - new_rule = Rule(name, self.region_name, **kwargs) - self.rules[new_rule.name] = new_rule - self.rules_order.append(new_rule.name) - return new_rule - def put_targets(self, name, event_bus_name, targets): # super simple ARN check invalid_arn = next( @@ -1342,13 +1438,12 @@ class EventsBackend(BaseBackend): rule_event_pattern = json.loads(event_pattern or "{}") rule_event_pattern["replay-name"] = [{"exists": False}] + rule_name = "Events-Archive-{}".format(name) rule = self.put_rule( - "Events-Archive-{}".format(name), - **{ - "EventPattern": json.dumps(rule_event_pattern), - "EventBusName": event_bus.name, - "ManagedBy": "prod.vhs.events.aws.internal", - }, + rule_name, + event_pattern=json.dumps(rule_event_pattern), + event_bus_name=event_bus.name, + managed_by="prod.vhs.events.aws.internal", ) self.put_targets( rule.name, diff --git a/moto/events/responses.py b/moto/events/responses.py index 497708673..b6afb3180 100644 --- a/moto/events/responses.py +++ b/moto/events/responses.py @@ -1,5 +1,4 @@ import json -import re from moto.core.responses import BaseResponse from moto.events import events_backends @@ -16,20 +15,6 @@ class EventsHandler(BaseResponse): """ return events_backends[self.region] - def _generate_rule_dict(self, rule): - return { - "Name": rule.name, - "Arn": rule.arn, - "EventPattern": str(rule.event_pattern), - "State": rule.state, - "Description": rule.description, - "ScheduleExpression": rule.schedule_exp, - "RoleArn": rule.role_arn, - "ManagedBy": rule.managed_by, - "EventBusName": rule.event_bus_name, - "CreatedBy": rule.created_by, - } - @property def request_params(self): if not hasattr(self, "_json_body"): @@ -61,6 +46,29 @@ class EventsHandler(BaseResponse): headers["status"] = status return json.dumps({"__type": type_, "message": message}), headers + def put_rule(self): + name = self._get_param("Name") + event_pattern = self._get_param("EventPattern") + scheduled_expression = self._get_param("ScheduleExpression") + state = self._get_param("State") + desc = self._get_param("Description") + role_arn = self._get_param("RoleArn") + event_bus_name = self._get_param("EventBusName") + tags = self._get_param("Tags") + + rule = self.events_backend.put_rule( + name, + scheduled_expression=scheduled_expression, + event_pattern=event_pattern, + state=state, + description=desc, + role_arn=role_arn, + event_bus_name=event_bus_name, + tags=tags, + ) + result = {"RuleArn": rule.arn} + return self._create_response(result) + def delete_rule(self): name = self._get_param("Name") @@ -78,13 +86,8 @@ class EventsHandler(BaseResponse): rule = self.events_backend.describe_rule(name) - if not rule: - return self.error( - "ResourceNotFoundException", "Rule " + name + " does not exist." - ) - - rule_dict = self._generate_rule_dict(rule) - return json.dumps(rule_dict), self.response_headers + result = rule.describe() + return self._create_response(result) def disable_rule(self): name = self._get_param("Name") @@ -138,7 +141,7 @@ class EventsHandler(BaseResponse): rules_obj = {"Rules": []} for rule in rules["Rules"]: - rules_obj["Rules"].append(self._generate_rule_dict(rule)) + rules_obj["Rules"].append(rule.describe()) if rules.get("NextToken"): rules_obj["NextToken"] = rules["NextToken"] @@ -177,48 +180,6 @@ class EventsHandler(BaseResponse): return json.dumps(response) - def put_rule(self): - name = self._get_param("Name") - event_pattern = self._get_param("EventPattern") - sched_exp = self._get_param("ScheduleExpression") - state = self._get_param("State") - desc = self._get_param("Description") - role_arn = self._get_param("RoleArn") - event_bus_name = self._get_param("EventBusName", "default") - - if event_pattern: - try: - json.loads(event_pattern) - except ValueError: - # Not quite as informative as the real error, but it'll work - # for now. - return self.error( - "InvalidEventPatternException", "Event pattern is not valid." - ) - - if sched_exp: - if not ( - re.match(r"^cron\(.*\)", sched_exp) - or re.match( - r"^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)", sched_exp - ) - ): - return self.error( - "ValidationException", "Parameter ScheduleExpression is not valid." - ) - - rule = self.events_backend.put_rule( - name, - ScheduleExpression=sched_exp, - EventPattern=event_pattern, - State=state, - Description=desc, - RoleArn=role_arn, - EventBusName=event_bus_name, - ) - - return json.dumps({"RuleArn": rule.arn}), self.response_headers - def put_targets(self): rule_name = self._get_param("Rule") event_bus_name = self._get_param("EventBusName", "default") diff --git a/tests/terraform-tests.success.txt b/tests/terraform-tests.success.txt index cb6445899..8ac0cef5d 100644 --- a/tests/terraform-tests.success.txt +++ b/tests/terraform-tests.success.txt @@ -10,6 +10,7 @@ TestAccAWSCloudWatchEventBus TestAccAWSCloudwatchEventBusPolicy TestAccAWSCloudWatchEventConnection TestAccAWSCloudWatchEventPermission +TestAccAWSCloudWatchEventRule TestAccAWSCloudwatchLogGroupDataSource TestAccAWSDataSourceCloudwatch TestAccAWSDataSourceElasticBeanstalkHostedZone diff --git a/tests/test_events/test_event_pattern.py b/tests/test_events/test_event_pattern.py index 612e49028..ceb303351 100644 --- a/tests/test_events/test_event_pattern.py +++ b/tests/test_events/test_event_pattern.py @@ -6,41 +6,45 @@ from moto.events.models import EventPattern def test_event_pattern_with_allowed_values_event_filter(): - pattern = EventPattern(json.dumps({"source": ["foo", "bar"]})) + pattern = EventPattern.load(json.dumps({"source": ["foo", "bar"]})) assert pattern.matches_event({"source": "foo"}) assert pattern.matches_event({"source": "bar"}) assert not pattern.matches_event({"source": "baz"}) def test_event_pattern_with_nested_event_filter(): - pattern = EventPattern(json.dumps({"detail": {"foo": ["bar"]}})) + pattern = EventPattern.load(json.dumps({"detail": {"foo": ["bar"]}})) assert pattern.matches_event({"detail": {"foo": "bar"}}) assert not pattern.matches_event({"detail": {"foo": "baz"}}) def test_event_pattern_with_exists_event_filter(): - foo_exists = EventPattern(json.dumps({"detail": {"foo": [{"exists": True}]}})) + foo_exists = EventPattern.load(json.dumps({"detail": {"foo": [{"exists": True}]}})) assert foo_exists.matches_event({"detail": {"foo": "bar"}}) assert not foo_exists.matches_event({"detail": {}}) # exists filters only match leaf nodes of an event assert not foo_exists.matches_event({"detail": {"foo": {"bar": "baz"}}}) - foo_not_exists = EventPattern(json.dumps({"detail": {"foo": [{"exists": False}]}})) + foo_not_exists = EventPattern.load( + json.dumps({"detail": {"foo": [{"exists": False}]}}) + ) assert not foo_not_exists.matches_event({"detail": {"foo": "bar"}}) assert foo_not_exists.matches_event({"detail": {}}) assert foo_not_exists.matches_event({"detail": {"foo": {"bar": "baz"}}}) - bar_exists = EventPattern(json.dumps({"detail": {"bar": [{"exists": True}]}})) + bar_exists = EventPattern.load(json.dumps({"detail": {"bar": [{"exists": True}]}})) assert not bar_exists.matches_event({"detail": {"foo": "bar"}}) assert not bar_exists.matches_event({"detail": {}}) - bar_not_exists = EventPattern(json.dumps({"detail": {"bar": [{"exists": False}]}})) + bar_not_exists = EventPattern.load( + json.dumps({"detail": {"bar": [{"exists": False}]}}) + ) assert bar_not_exists.matches_event({"detail": {"foo": "bar"}}) assert bar_not_exists.matches_event({"detail": {}}) def test_event_pattern_with_prefix_event_filter(): - pattern = EventPattern(json.dumps({"detail": {"foo": [{"prefix": "bar"}]}})) + pattern = EventPattern.load(json.dumps({"detail": {"foo": [{"prefix": "bar"}]}})) assert pattern.matches_event({"detail": {"foo": "bar"}}) assert pattern.matches_event({"detail": {"foo": "bar!"}}) assert not pattern.matches_event({"detail": {"foo": "ba"}}) @@ -59,7 +63,7 @@ def test_event_pattern_with_prefix_event_filter(): def test_event_pattern_with_single_numeric_event_filter( operator, compare_to, should_match, should_not_match ): - pattern = EventPattern( + pattern = EventPattern.load( json.dumps({"detail": {"foo": [{"numeric": [operator, compare_to]}]}}) ) for number in should_match: @@ -71,7 +75,7 @@ def test_event_pattern_with_single_numeric_event_filter( def test_event_pattern_with_multi_numeric_event_filter(): events = [{"detail": {"foo": number}} for number in range(5)] - one_or_two = EventPattern( + one_or_two = EventPattern.load( json.dumps({"detail": {"foo": [{"numeric": [">=", 1, "<", 3]}]}}) ) assert not one_or_two.matches_event(events[0]) @@ -80,7 +84,7 @@ def test_event_pattern_with_multi_numeric_event_filter(): assert not one_or_two.matches_event(events[3]) assert not one_or_two.matches_event(events[4]) - two_or_three = EventPattern( + two_or_three = EventPattern.load( json.dumps({"detail": {"foo": [{"numeric": [">", 1, "<=", 3]}]}}) ) assert not two_or_three.matches_event(events[0]) @@ -92,8 +96,8 @@ def test_event_pattern_with_multi_numeric_event_filter(): @pytest.mark.parametrize( "pattern, expected_str", - [('{"source": ["foo", "bar"]}', '{"source": ["foo", "bar"]}'), (None, ""),], + [('{"source": ["foo", "bar"]}', '{"source": ["foo", "bar"]}'), (None, None),], ) -def test_event_pattern_str(pattern, expected_str): - event_pattern = EventPattern(pattern) - assert str(event_pattern) == expected_str +def test_event_pattern_dump(pattern, expected_str): + event_pattern = EventPattern.load(pattern) + assert event_pattern.dump() == expected_str diff --git a/tests/test_events/test_events.py b/tests/test_events/test_events.py index 25ab1b160..0582125ac 100644 --- a/tests/test_events/test_events.py +++ b/tests/test_events/test_events.py @@ -1050,7 +1050,9 @@ def test_create_archive_error_invalid_event_pattern(): ex.operation_name.should.equal("CreateArchive") ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) ex.response["Error"]["Code"].should.contain("InvalidEventPatternException") - ex.response["Error"]["Message"].should.equal("Event pattern is not valid.") + ex.response["Error"]["Message"].should.equal( + "Event pattern is not valid. Reason: Invalid JSON" + ) @mock_events @@ -1080,7 +1082,9 @@ def test_create_archive_error_invalid_event_pattern_not_an_array(): ex.operation_name.should.equal("CreateArchive") ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) ex.response["Error"]["Code"].should.contain("InvalidEventPatternException") - ex.response["Error"]["Message"].should.equal("Event pattern is not valid.") + ex.response["Error"]["Message"].should.equal( + "Event pattern is not valid. Reason: 'key_6' must be an object or an array" + ) @mock_events @@ -1378,7 +1382,9 @@ def test_update_archive_error_invalid_event_pattern(): ex.operation_name.should.equal("UpdateArchive") ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) ex.response["Error"]["Code"].should.contain("InvalidEventPatternException") - ex.response["Error"]["Message"].should.equal("Event pattern is not valid.") + ex.response["Error"]["Message"].should.equal( + "Event pattern is not valid. Reason: Invalid JSON" + ) @mock_events @@ -2399,7 +2405,6 @@ def test_delete_connection_success(): response = client.delete_connection(Name=conn_name) # Then - expected_arn = f"arn:aws:events:eu-central-1:{ACCOUNT_ID}:connection/{conn_name}/" assert response["ConnectionArn"] == created_connection["ConnectionArn"] assert response["ConnectionState"] == created_connection["ConnectionState"] assert response["CreationTime"] == created_connection["CreationTime"]