From 2b9c98895c23a131ad271f92cac165b0894ac749 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Mon, 20 Mar 2023 13:47:22 -0100 Subject: [PATCH] Events: Separate Rules by the EventBus they're assigned to (#6091) --- moto/events/models.py | 119 +++++++++++------- moto/events/notifications.py | 15 +-- moto/events/responses.py | 35 ++++-- .../terraform-tests.success.txt | 12 ++ tests/test_events/test_events.py | 48 +++++++ 5 files changed, 165 insertions(+), 64 deletions(-) diff --git a/moto/events/models.py b/moto/events/models.py index 9ef96c647..f2908c22e 100644 --- a/moto/events/models.py +++ b/moto/events/models.py @@ -96,7 +96,7 @@ class Rule(CloudFormationModel): def delete(self, account_id: str, region_name: str) -> None: event_backend = events_backends[account_id][region_name] - event_backend.delete_rule(name=self.name) + event_backend.delete_rule(name=self.name, event_bus_arn=self.event_bus_name) def put_targets(self, targets: List[Dict[str, Any]]) -> None: # Not testing for valid ARNs. @@ -113,11 +113,7 @@ class Rule(CloudFormationModel): if index is not None: self.targets.pop(index) - def send_to_targets(self, event_bus_name: str, event: Dict[str, Any]) -> None: - event_bus_name = event_bus_name.split("/")[-1] - if event_bus_name != self.event_bus_name.split("/")[-1]: - return - + def send_to_targets(self, event: Dict[str, Any]) -> None: if not self.event_pattern.matches_event(event): return @@ -277,7 +273,7 @@ class Rule(CloudFormationModel): state = properties.get("State") desc = properties.get("Description") role_arn = properties.get("RoleArn") - event_bus_name = properties.get("EventBusName") + event_bus_arn = properties.get("EventBusName") tags = properties.get("Tags") backend = events_backends[account_id][region_name] @@ -288,7 +284,7 @@ class Rule(CloudFormationModel): state=state, description=desc, role_arn=role_arn, - event_bus_name=event_bus_name, + event_bus_arn=event_bus_arn, tags=tags, ) @@ -315,7 +311,9 @@ class Rule(CloudFormationModel): region_name: str, ) -> None: event_backend = events_backends[account_id][region_name] - event_backend.delete_rule(resource_name) + properties = cloudformation_json["Properties"] + event_bus_arn = properties.get("EventBusName") + event_backend.delete_rule(resource_name, event_bus_arn) def describe(self) -> Dict[str, Any]: attributes = { @@ -351,6 +349,7 @@ class EventBus(CloudFormationModel): self.tags = tags or [] self._statements: Dict[str, EventBusPolicyStatement] = {} + self.rules: Dict[str, Rule] = OrderedDict() @property def policy(self) -> Optional[str]: @@ -738,9 +737,9 @@ class Replay(BaseModel): for event in archive.events: event_backend = events_backends[self.account_id][self.region] - for rule in event_backend.rules.values(): + event_bus = event_backend.describe_event_bus(event_bus_name) + for rule in event_bus.rules.values(): rule.send_to_targets( - event_bus_name, dict( event, **{"id": str(random.uuid4()), "replay-name": self.name} # type: ignore ), @@ -996,7 +995,6 @@ class EventsBackend(BaseBackend): def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.rules: Dict[str, Rule] = OrderedDict() self.next_tokens: Dict[str, int] = {} self.event_buses: Dict[str, EventBus] = {} self.event_sources: Dict[str, str] = {} @@ -1070,7 +1068,7 @@ class EventsBackend(BaseBackend): self, name: str, description: Optional[str] = None, - event_bus_name: Optional[str] = None, + event_bus_arn: Optional[str] = None, event_pattern: Optional[str] = None, role_arn: Optional[str] = None, scheduled_expression: Optional[str] = None, @@ -1078,7 +1076,7 @@ class EventsBackend(BaseBackend): managed_by: Optional[str] = None, tags: Optional[List[Dict[str, str]]] = None, ) -> Rule: - event_bus_name = event_bus_name or "default" + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) if not event_pattern and not scheduled_expression: raise JsonRESTError( @@ -1098,7 +1096,8 @@ class EventsBackend(BaseBackend): ): raise ValidationException("Parameter ScheduleExpression is not valid.") - existing_rule = self.rules.get(name) + event_bus = self._get_event_bus(event_bus_name) + existing_rule = event_bus.rules.get(name) targets = existing_rule.targets if existing_rule else list() rule = Rule( name, @@ -1113,15 +1112,22 @@ class EventsBackend(BaseBackend): managed_by, targets=targets, ) - self.rules[name] = rule + event_bus.rules[name] = rule if tags: self.tagger.tag_resource(rule.arn, tags) return rule - def delete_rule(self, name: str) -> None: - rule = self.rules.get(name) + def _normalize_event_bus_arn(self, event_bus_arn: Optional[str]) -> str: + if event_bus_arn is None: + return "default" + return event_bus_arn.split("/")[-1] + + def delete_rule(self, name: str, event_bus_arn: Optional[str]) -> None: + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) + rule = event_bus.rules.get(name) if not rule: return if len(rule.targets) > 0: @@ -1130,33 +1136,41 @@ class EventsBackend(BaseBackend): arn = rule.arn if self.tagger.has_tags(arn): self.tagger.delete_all_tags_for_resource(arn) - self.rules.pop(name) + event_bus.rules.pop(name) - def describe_rule(self, name: str) -> Rule: - rule = self.rules.get(name) + def describe_rule(self, name: str, event_bus_arn: Optional[str]) -> Rule: + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) + rule = event_bus.rules.get(name) if not rule: raise ResourceNotFoundException(f"Rule {name} does not exist.") return rule - def disable_rule(self, name: str) -> bool: - if name in self.rules: - self.rules[name].disable() + def disable_rule(self, name: str, event_bus_arn: Optional[str]) -> bool: + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) + if name in event_bus.rules: + event_bus.rules[name].disable() return True return False - def enable_rule(self, name: str) -> bool: - if name in self.rules: - self.rules[name].enable() + def enable_rule(self, name: str, event_bus_arn: Optional[str]) -> bool: + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) + if name in event_bus.rules: + event_bus.rules[name].enable() return True return False @paginate(pagination_model=PAGINATION_MODEL) - def list_rule_names_by_target(self, target_arn: str) -> List[Rule]: # type: ignore[misc] + def list_rule_names_by_target(self, target_arn: str, event_bus_arn: Optional[str]) -> List[Rule]: # type: ignore[misc] + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) matching_rules = [] - for _, rule in self.rules.items(): + for _, rule in event_bus.rules.items(): for target in rule.targets: if target["Arn"] == target_arn: matching_rules.append(rule) @@ -1164,7 +1178,9 @@ class EventsBackend(BaseBackend): return matching_rules @paginate(pagination_model=PAGINATION_MODEL) - def list_rules(self, prefix: Optional[str] = None) -> List[Rule]: # type: ignore[misc] + def list_rules(self, prefix: Optional[str] = None, event_bus_arn: Optional[str] = None) -> List[Rule]: # type: ignore[misc] + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) match_string = ".*" if prefix is not None: match_string = "^" + prefix + match_string @@ -1173,7 +1189,7 @@ class EventsBackend(BaseBackend): matching_rules = [] - for name, rule in self.rules.items(): + for name, rule in event_bus.rules.items(): if match_regex.match(name): matching_rules.append(rule) @@ -1182,12 +1198,15 @@ class EventsBackend(BaseBackend): def list_targets_by_rule( self, rule_id: str, + event_bus_arn: Optional[str], next_token: Optional[str] = None, limit: Optional[str] = None, ) -> Dict[str, Any]: # We'll let a KeyError exception be thrown for response to handle if # rule doesn't exist. - rule = self.rules[rule_id] + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) + rule = event_bus.rules[rule_id] start_index, end_index, new_next_token = self._process_token_and_limits( len(rule.targets), next_token, limit @@ -1206,8 +1225,10 @@ class EventsBackend(BaseBackend): return return_obj def put_targets( - self, name: str, event_bus_name: str, targets: List[Dict[str, Any]] + self, name: str, event_bus_arn: Optional[str], targets: List[Dict[str, Any]] ) -> None: + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) # super simple ARN check invalid_arn = next( ( @@ -1234,7 +1255,7 @@ class EventsBackend(BaseBackend): f"Parameter(s) SqsParameters must be specified for target: {target['Id']}." ) - rule = self.rules.get(name) + rule = event_bus.rules.get(name) if not rule: raise ResourceNotFoundException( @@ -1301,11 +1322,13 @@ class EventsBackend(BaseBackend): entries.append({"EventId": event_id}) # if 'EventBusName' is not especially set, it will be sent to the default one - event_bus_name = event.get("EventBusName", "default") + event_bus_name = self._normalize_event_bus_arn( + event.get("EventBusName") + ) - for rule in self.rules.values(): + event_bus = self.describe_event_bus(event_bus_name) + for rule in event_bus.rules.values(): rule.send_to_targets( - event_bus_name, { "version": "0", "id": event_id, @@ -1321,8 +1344,12 @@ class EventsBackend(BaseBackend): return entries - def remove_targets(self, name: str, event_bus_name: str, ids: List[str]) -> None: - rule = self.rules.get(name) + def remove_targets( + self, name: str, event_bus_arn: Optional[str], ids: List[str] + ) -> None: + event_bus_name = self._normalize_event_bus_arn(event_bus_arn) + event_bus = self._get_event_bus(event_bus_name) + rule = event_bus.rules.get(name) if not rule: raise ResourceNotFoundException( @@ -1499,8 +1526,8 @@ class EventsBackend(BaseBackend): def list_tags_for_resource(self, arn: str) -> Dict[str, List[Dict[str, str]]]: name = arn.split("/")[-1] - registries = [self.rules, self.event_buses] - for registry in registries: + rules = [bus.rules for bus in self.event_buses.values()] + for registry in rules + [self.event_buses]: if name in registry: # type: ignore return self.tagger.list_tags_for_resource(registry[name].arn) # type: ignore raise ResourceNotFoundException( @@ -1509,8 +1536,8 @@ class EventsBackend(BaseBackend): def tag_resource(self, arn: str, tags: List[Dict[str, str]]) -> None: name = arn.split("/")[-1] - registries = [self.rules, self.event_buses] - for registry in registries: + rules = [bus.rules for bus in self.event_buses.values()] + for registry in rules + [self.event_buses]: if name in registry: # type: ignore self.tagger.tag_resource(registry[name].arn, tags) # type: ignore return @@ -1520,8 +1547,8 @@ class EventsBackend(BaseBackend): def untag_resource(self, arn: str, tag_names: List[str]) -> None: name = arn.split("/")[-1] - registries = [self.rules, self.event_buses] - for registry in registries: + rules = [bus.rules for bus in self.event_buses.values()] + for registry in rules + [self.event_buses]: if name in registry: # type: ignore self.tagger.untag_resource_using_names(registry[name].arn, tag_names) # type: ignore return @@ -1566,7 +1593,7 @@ class EventsBackend(BaseBackend): rule = self.put_rule( rule_name, event_pattern=json.dumps(rule_event_pattern), - event_bus_name=event_bus.name, + event_bus_arn=event_bus.name, managed_by="prod.vhs.events.aws.internal", ) self.put_targets( diff --git a/moto/events/notifications.py b/moto/events/notifications.py index dad88d376..0883dd147 100644 --- a/moto/events/notifications.py +++ b/moto/events/notifications.py @@ -43,13 +43,14 @@ def _send_safe_notification( for account_id, account in events_backends.items(): for backend in account.values(): applicable_targets = [] - for rule in backend.rules.values(): - if rule.state != "ENABLED": - continue - pattern = rule.event_pattern.get_pattern() - if source in pattern.get("source", []): - if event_name in pattern.get("detail", {}).get("eventName", []): - applicable_targets.extend(rule.targets) + for event_bus in backend.event_buses.values(): + for rule in event_bus.rules.values(): + if rule.state != "ENABLED": + continue + pattern = rule.event_pattern.get_pattern() + if source in pattern.get("source", []): + if event_name in pattern.get("detail", {}).get("eventName", []): + applicable_targets.extend(rule.targets) for target in applicable_targets: if target.get("Arn", "").startswith("arn:aws:lambda"): diff --git a/moto/events/responses.py b/moto/events/responses.py index c1b25c9ba..17375fde1 100644 --- a/moto/events/responses.py +++ b/moto/events/responses.py @@ -40,7 +40,7 @@ class EventsHandler(BaseResponse): state = self._get_param("State") desc = self._get_param("Description") role_arn = self._get_param("RoleArn") - event_bus_name = self._get_param("EventBusName") + event_bus_arn = self._get_param("EventBusName") tags = self._get_param("Tags") rule = self.events_backend.put_rule( @@ -50,7 +50,7 @@ class EventsHandler(BaseResponse): state=state, description=desc, role_arn=role_arn, - event_bus_name=event_bus_name, + event_bus_arn=event_bus_arn, tags=tags, ) result = {"RuleArn": rule.arn} @@ -58,31 +58,34 @@ class EventsHandler(BaseResponse): def delete_rule(self) -> Tuple[str, Dict[str, Any]]: name = self._get_param("Name") + event_bus_arn = self._get_param("EventBusName") if not name: return self.error("ValidationException", "Parameter Name is required.") - self.events_backend.delete_rule(name) + self.events_backend.delete_rule(name, event_bus_arn) return "", self.response_headers def describe_rule(self) -> Tuple[str, Dict[str, Any]]: name = self._get_param("Name") + event_bus_arn = self._get_param("EventBusName") if not name: return self.error("ValidationException", "Parameter Name is required.") - rule = self.events_backend.describe_rule(name) + rule = self.events_backend.describe_rule(name, event_bus_arn) result = rule.describe() return self._create_response(result) def disable_rule(self) -> Tuple[str, Dict[str, Any]]: name = self._get_param("Name") + event_bus_arn = self._get_param("EventBusName") if not name: return self.error("ValidationException", "Parameter Name is required.") - if not self.events_backend.disable_rule(name): + if not self.events_backend.disable_rule(name, event_bus_arn): return self.error( "ResourceNotFoundException", "Rule " + name + " does not exist." ) @@ -91,11 +94,12 @@ class EventsHandler(BaseResponse): def enable_rule(self) -> Tuple[str, Dict[str, Any]]: name = self._get_param("Name") + event_bus_arn = self._get_param("EventBusName") if not name: return self.error("ValidationException", "Parameter Name is required.") - if not self.events_backend.enable_rule(name): + if not self.events_backend.enable_rule(name, event_bus_arn): return self.error( "ResourceNotFoundException", "Rule " + name + " does not exist." ) @@ -107,6 +111,7 @@ class EventsHandler(BaseResponse): def list_rule_names_by_target(self) -> Tuple[str, Dict[str, Any]]: target_arn = self._get_param("TargetArn") + event_bus_arn = self._get_param("EventBusName") next_token = self._get_param("NextToken") limit = self._get_param("Limit") @@ -114,7 +119,10 @@ class EventsHandler(BaseResponse): return self.error("ValidationException", "Parameter TargetArn is required.") rules, token = self.events_backend.list_rule_names_by_target( - target_arn=target_arn, next_token=next_token, limit=limit + target_arn=target_arn, + event_bus_arn=event_bus_arn, + next_token=next_token, + limit=limit, ) res = {"RuleNames": [rule.name for rule in rules], "NextToken": token} @@ -123,11 +131,15 @@ class EventsHandler(BaseResponse): def list_rules(self) -> Tuple[str, Dict[str, Any]]: prefix = self._get_param("NamePrefix") + event_bus_arn = self._get_param("EventBusName") next_token = self._get_param("NextToken") limit = self._get_param("Limit") rules, token = self.events_backend.list_rules( - prefix=prefix, next_token=next_token, limit=limit + prefix=prefix, + event_bus_arn=event_bus_arn, + next_token=next_token, + limit=limit, ) rules_obj = { "Rules": [rule.describe() for rule in rules], @@ -138,6 +150,7 @@ class EventsHandler(BaseResponse): def list_targets_by_rule(self) -> Tuple[str, Dict[str, Any]]: rule_name = self._get_param("Rule") + event_bus_arn = self._get_param("EventBusName") next_token = self._get_param("NextToken") limit = self._get_param("Limit") @@ -146,7 +159,7 @@ class EventsHandler(BaseResponse): try: targets = self.events_backend.list_targets_by_rule( - rule_name, next_token, limit + rule_name, event_bus_arn, next_token, limit ) except KeyError: return self.error( @@ -170,7 +183,7 @@ class EventsHandler(BaseResponse): def put_targets(self) -> Tuple[str, Dict[str, Any]]: rule_name = self._get_param("Rule") - event_bus_name = self._get_param("EventBusName", "default") + event_bus_name = self._get_param("EventBusName") targets = self._get_param("Targets") self.events_backend.put_targets(rule_name, event_bus_name, targets) @@ -182,7 +195,7 @@ class EventsHandler(BaseResponse): def remove_targets(self) -> Tuple[str, Dict[str, Any]]: rule_name = self._get_param("Rule") - event_bus_name = self._get_param("EventBusName", "default") + event_bus_name = self._get_param("EventBusName") ids = self._get_param("Ids") self.events_backend.remove_targets(rule_name, event_bus_name, ids) diff --git a/tests/terraformtests/terraform-tests.success.txt b/tests/terraformtests/terraform-tests.success.txt index 0e4e3f71d..62ccd20f5 100644 --- a/tests/terraformtests/terraform-tests.success.txt +++ b/tests/terraformtests/terraform-tests.success.txt @@ -265,6 +265,18 @@ events: - TestAccEventsConnection - TestAccEventsConnectionDataSource - TestAccEventsPermission + - TestAccEventsRule + - TestAccEventsTarget_basic + - TestAccEventsTarget_batch + - TestAccEventsTarget_disappears + - TestAccEventsTarget_eventBusName + - TestAccEventsTarget_ecs + - TestAccEventsTarget_eventBusARN + - TestAccEventsTarget_full + - TestAccEventsTarget_generatedTargetID + - TestAccEventsTarget_inputTransformer + - TestAccEventsTarget_kinesis + - TestAccEventsTarget_ssmDocument firehose: - TestAccFirehoseDeliveryStreamDataSource_basic - TestAccFirehoseDeliveryStream_basic diff --git a/tests/test_events/test_events.py b/tests/test_events/test_events.py index acd00edbc..4f896aa50 100644 --- a/tests/test_events/test_events.py +++ b/tests/test_events/test_events.py @@ -105,6 +105,20 @@ def test_put_rule(): rules[0]["State"].should.equal("ENABLED") +@mock_events +def test_put_rule__where_event_bus_name_is_arn(): + client = boto3.client("events", "us-west-2") + event_bus_name = "test-bus" + event_bus_arn = client.create_event_bus(Name=event_bus_name)["EventBusArn"] + + rule_arn = client.put_rule( + Name="my-event", + EventPattern='{"source": ["test-source"]}', + EventBusName=event_bus_arn, + )["RuleArn"] + assert rule_arn == f"arn:aws:events:us-west-2:{ACCOUNT_ID}:rule/test-bus/my-event" + + @mock_events def test_put_rule_error_schedule_expression_custom_event_bus(): # given @@ -342,6 +356,40 @@ def test_list_targets_by_rule(): assert len(targets["Targets"]) == len(expected_targets) +@mock_events +def test_list_targets_by_rule_for_different_event_bus(): + client = generate_environment() + + client.create_event_bus(Name="newEventBus") + + client.put_rule(Name="test1", EventBusName="newEventBus", EventPattern="{}") + client.put_targets( + Rule="test1", + EventBusName="newEventBus", + Targets=[ + { + "Id": "newtarget", + "Arn": "arn:", + } + ], + ) + + # Total targets with this rule is 7, but, from the docs: + # If you omit [the eventBusName-parameter], the default event bus is used. + targets = client.list_targets_by_rule(Rule="test1")["Targets"] + assert len([t["Id"] for t in targets]) == 6 + + targets = client.list_targets_by_rule(Rule="test1", EventBusName="default")[ + "Targets" + ] + assert len([t["Id"] for t in targets]) == 6 + + targets = client.list_targets_by_rule(Rule="test1", EventBusName="newEventBus")[ + "Targets" + ] + assert [t["Id"] for t in targets] == ["newtarget"] + + @mock_events def test_remove_targets(): rule_name = get_random_rule()["Name"]