Events: Separate Rules by the EventBus they're assigned to (#6091)

This commit is contained in:
Bert Blommers 2023-03-20 13:47:22 -01:00 committed by GitHub
parent 86a00f1e4b
commit 2b9c98895c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 165 additions and 64 deletions

View File

@ -96,7 +96,7 @@ class Rule(CloudFormationModel):
def delete(self, account_id: str, region_name: str) -> None:
event_backend = events_backends[account_id][region_name]
event_backend.delete_rule(name=self.name)
event_backend.delete_rule(name=self.name, event_bus_arn=self.event_bus_name)
def put_targets(self, targets: List[Dict[str, Any]]) -> None:
# Not testing for valid ARNs.
@ -113,11 +113,7 @@ class Rule(CloudFormationModel):
if index is not None:
self.targets.pop(index)
def send_to_targets(self, event_bus_name: str, event: Dict[str, Any]) -> None:
event_bus_name = event_bus_name.split("/")[-1]
if event_bus_name != self.event_bus_name.split("/")[-1]:
return
def send_to_targets(self, event: Dict[str, Any]) -> None:
if not self.event_pattern.matches_event(event):
return
@ -277,7 +273,7 @@ class Rule(CloudFormationModel):
state = properties.get("State")
desc = properties.get("Description")
role_arn = properties.get("RoleArn")
event_bus_name = properties.get("EventBusName")
event_bus_arn = properties.get("EventBusName")
tags = properties.get("Tags")
backend = events_backends[account_id][region_name]
@ -288,7 +284,7 @@ class Rule(CloudFormationModel):
state=state,
description=desc,
role_arn=role_arn,
event_bus_name=event_bus_name,
event_bus_arn=event_bus_arn,
tags=tags,
)
@ -315,7 +311,9 @@ class Rule(CloudFormationModel):
region_name: str,
) -> None:
event_backend = events_backends[account_id][region_name]
event_backend.delete_rule(resource_name)
properties = cloudformation_json["Properties"]
event_bus_arn = properties.get("EventBusName")
event_backend.delete_rule(resource_name, event_bus_arn)
def describe(self) -> Dict[str, Any]:
attributes = {
@ -351,6 +349,7 @@ class EventBus(CloudFormationModel):
self.tags = tags or []
self._statements: Dict[str, EventBusPolicyStatement] = {}
self.rules: Dict[str, Rule] = OrderedDict()
@property
def policy(self) -> Optional[str]:
@ -738,9 +737,9 @@ class Replay(BaseModel):
for event in archive.events:
event_backend = events_backends[self.account_id][self.region]
for rule in event_backend.rules.values():
event_bus = event_backend.describe_event_bus(event_bus_name)
for rule in event_bus.rules.values():
rule.send_to_targets(
event_bus_name,
dict(
event, **{"id": str(random.uuid4()), "replay-name": self.name} # type: ignore
),
@ -996,7 +995,6 @@ class EventsBackend(BaseBackend):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.rules: Dict[str, Rule] = OrderedDict()
self.next_tokens: Dict[str, int] = {}
self.event_buses: Dict[str, EventBus] = {}
self.event_sources: Dict[str, str] = {}
@ -1070,7 +1068,7 @@ class EventsBackend(BaseBackend):
self,
name: str,
description: Optional[str] = None,
event_bus_name: Optional[str] = None,
event_bus_arn: Optional[str] = None,
event_pattern: Optional[str] = None,
role_arn: Optional[str] = None,
scheduled_expression: Optional[str] = None,
@ -1078,7 +1076,7 @@ class EventsBackend(BaseBackend):
managed_by: Optional[str] = None,
tags: Optional[List[Dict[str, str]]] = None,
) -> Rule:
event_bus_name = event_bus_name or "default"
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
if not event_pattern and not scheduled_expression:
raise JsonRESTError(
@ -1098,7 +1096,8 @@ class EventsBackend(BaseBackend):
):
raise ValidationException("Parameter ScheduleExpression is not valid.")
existing_rule = self.rules.get(name)
event_bus = self._get_event_bus(event_bus_name)
existing_rule = event_bus.rules.get(name)
targets = existing_rule.targets if existing_rule else list()
rule = Rule(
name,
@ -1113,15 +1112,22 @@ class EventsBackend(BaseBackend):
managed_by,
targets=targets,
)
self.rules[name] = rule
event_bus.rules[name] = rule
if tags:
self.tagger.tag_resource(rule.arn, tags)
return rule
def delete_rule(self, name: str) -> None:
rule = self.rules.get(name)
def _normalize_event_bus_arn(self, event_bus_arn: Optional[str]) -> str:
if event_bus_arn is None:
return "default"
return event_bus_arn.split("/")[-1]
def delete_rule(self, name: str, event_bus_arn: Optional[str]) -> None:
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
rule = event_bus.rules.get(name)
if not rule:
return
if len(rule.targets) > 0:
@ -1130,33 +1136,41 @@ class EventsBackend(BaseBackend):
arn = rule.arn
if self.tagger.has_tags(arn):
self.tagger.delete_all_tags_for_resource(arn)
self.rules.pop(name)
event_bus.rules.pop(name)
def describe_rule(self, name: str) -> Rule:
rule = self.rules.get(name)
def describe_rule(self, name: str, event_bus_arn: Optional[str]) -> Rule:
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
rule = event_bus.rules.get(name)
if not rule:
raise ResourceNotFoundException(f"Rule {name} does not exist.")
return rule
def disable_rule(self, name: str) -> bool:
if name in self.rules:
self.rules[name].disable()
def disable_rule(self, name: str, event_bus_arn: Optional[str]) -> bool:
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
if name in event_bus.rules:
event_bus.rules[name].disable()
return True
return False
def enable_rule(self, name: str) -> bool:
if name in self.rules:
self.rules[name].enable()
def enable_rule(self, name: str, event_bus_arn: Optional[str]) -> bool:
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
if name in event_bus.rules:
event_bus.rules[name].enable()
return True
return False
@paginate(pagination_model=PAGINATION_MODEL)
def list_rule_names_by_target(self, target_arn: str) -> List[Rule]: # type: ignore[misc]
def list_rule_names_by_target(self, target_arn: str, event_bus_arn: Optional[str]) -> List[Rule]: # type: ignore[misc]
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
matching_rules = []
for _, rule in self.rules.items():
for _, rule in event_bus.rules.items():
for target in rule.targets:
if target["Arn"] == target_arn:
matching_rules.append(rule)
@ -1164,7 +1178,9 @@ class EventsBackend(BaseBackend):
return matching_rules
@paginate(pagination_model=PAGINATION_MODEL)
def list_rules(self, prefix: Optional[str] = None) -> List[Rule]: # type: ignore[misc]
def list_rules(self, prefix: Optional[str] = None, event_bus_arn: Optional[str] = None) -> List[Rule]: # type: ignore[misc]
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
match_string = ".*"
if prefix is not None:
match_string = "^" + prefix + match_string
@ -1173,7 +1189,7 @@ class EventsBackend(BaseBackend):
matching_rules = []
for name, rule in self.rules.items():
for name, rule in event_bus.rules.items():
if match_regex.match(name):
matching_rules.append(rule)
@ -1182,12 +1198,15 @@ class EventsBackend(BaseBackend):
def list_targets_by_rule(
self,
rule_id: str,
event_bus_arn: Optional[str],
next_token: Optional[str] = None,
limit: Optional[str] = None,
) -> Dict[str, Any]:
# We'll let a KeyError exception be thrown for response to handle if
# rule doesn't exist.
rule = self.rules[rule_id]
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
rule = event_bus.rules[rule_id]
start_index, end_index, new_next_token = self._process_token_and_limits(
len(rule.targets), next_token, limit
@ -1206,8 +1225,10 @@ class EventsBackend(BaseBackend):
return return_obj
def put_targets(
self, name: str, event_bus_name: str, targets: List[Dict[str, Any]]
self, name: str, event_bus_arn: Optional[str], targets: List[Dict[str, Any]]
) -> None:
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
# super simple ARN check
invalid_arn = next(
(
@ -1234,7 +1255,7 @@ class EventsBackend(BaseBackend):
f"Parameter(s) SqsParameters must be specified for target: {target['Id']}."
)
rule = self.rules.get(name)
rule = event_bus.rules.get(name)
if not rule:
raise ResourceNotFoundException(
@ -1301,11 +1322,13 @@ class EventsBackend(BaseBackend):
entries.append({"EventId": event_id})
# if 'EventBusName' is not especially set, it will be sent to the default one
event_bus_name = event.get("EventBusName", "default")
event_bus_name = self._normalize_event_bus_arn(
event.get("EventBusName")
)
for rule in self.rules.values():
event_bus = self.describe_event_bus(event_bus_name)
for rule in event_bus.rules.values():
rule.send_to_targets(
event_bus_name,
{
"version": "0",
"id": event_id,
@ -1321,8 +1344,12 @@ class EventsBackend(BaseBackend):
return entries
def remove_targets(self, name: str, event_bus_name: str, ids: List[str]) -> None:
rule = self.rules.get(name)
def remove_targets(
self, name: str, event_bus_arn: Optional[str], ids: List[str]
) -> None:
event_bus_name = self._normalize_event_bus_arn(event_bus_arn)
event_bus = self._get_event_bus(event_bus_name)
rule = event_bus.rules.get(name)
if not rule:
raise ResourceNotFoundException(
@ -1499,8 +1526,8 @@ class EventsBackend(BaseBackend):
def list_tags_for_resource(self, arn: str) -> Dict[str, List[Dict[str, str]]]:
name = arn.split("/")[-1]
registries = [self.rules, self.event_buses]
for registry in registries:
rules = [bus.rules for bus in self.event_buses.values()]
for registry in rules + [self.event_buses]:
if name in registry: # type: ignore
return self.tagger.list_tags_for_resource(registry[name].arn) # type: ignore
raise ResourceNotFoundException(
@ -1509,8 +1536,8 @@ class EventsBackend(BaseBackend):
def tag_resource(self, arn: str, tags: List[Dict[str, str]]) -> None:
name = arn.split("/")[-1]
registries = [self.rules, self.event_buses]
for registry in registries:
rules = [bus.rules for bus in self.event_buses.values()]
for registry in rules + [self.event_buses]:
if name in registry: # type: ignore
self.tagger.tag_resource(registry[name].arn, tags) # type: ignore
return
@ -1520,8 +1547,8 @@ class EventsBackend(BaseBackend):
def untag_resource(self, arn: str, tag_names: List[str]) -> None:
name = arn.split("/")[-1]
registries = [self.rules, self.event_buses]
for registry in registries:
rules = [bus.rules for bus in self.event_buses.values()]
for registry in rules + [self.event_buses]:
if name in registry: # type: ignore
self.tagger.untag_resource_using_names(registry[name].arn, tag_names) # type: ignore
return
@ -1566,7 +1593,7 @@ class EventsBackend(BaseBackend):
rule = self.put_rule(
rule_name,
event_pattern=json.dumps(rule_event_pattern),
event_bus_name=event_bus.name,
event_bus_arn=event_bus.name,
managed_by="prod.vhs.events.aws.internal",
)
self.put_targets(

View File

@ -43,13 +43,14 @@ def _send_safe_notification(
for account_id, account in events_backends.items():
for backend in account.values():
applicable_targets = []
for rule in backend.rules.values():
if rule.state != "ENABLED":
continue
pattern = rule.event_pattern.get_pattern()
if source in pattern.get("source", []):
if event_name in pattern.get("detail", {}).get("eventName", []):
applicable_targets.extend(rule.targets)
for event_bus in backend.event_buses.values():
for rule in event_bus.rules.values():
if rule.state != "ENABLED":
continue
pattern = rule.event_pattern.get_pattern()
if source in pattern.get("source", []):
if event_name in pattern.get("detail", {}).get("eventName", []):
applicable_targets.extend(rule.targets)
for target in applicable_targets:
if target.get("Arn", "").startswith("arn:aws:lambda"):

View File

@ -40,7 +40,7 @@ class EventsHandler(BaseResponse):
state = self._get_param("State")
desc = self._get_param("Description")
role_arn = self._get_param("RoleArn")
event_bus_name = self._get_param("EventBusName")
event_bus_arn = self._get_param("EventBusName")
tags = self._get_param("Tags")
rule = self.events_backend.put_rule(
@ -50,7 +50,7 @@ class EventsHandler(BaseResponse):
state=state,
description=desc,
role_arn=role_arn,
event_bus_name=event_bus_name,
event_bus_arn=event_bus_arn,
tags=tags,
)
result = {"RuleArn": rule.arn}
@ -58,31 +58,34 @@ class EventsHandler(BaseResponse):
def delete_rule(self) -> Tuple[str, Dict[str, Any]]:
name = self._get_param("Name")
event_bus_arn = self._get_param("EventBusName")
if not name:
return self.error("ValidationException", "Parameter Name is required.")
self.events_backend.delete_rule(name)
self.events_backend.delete_rule(name, event_bus_arn)
return "", self.response_headers
def describe_rule(self) -> Tuple[str, Dict[str, Any]]:
name = self._get_param("Name")
event_bus_arn = self._get_param("EventBusName")
if not name:
return self.error("ValidationException", "Parameter Name is required.")
rule = self.events_backend.describe_rule(name)
rule = self.events_backend.describe_rule(name, event_bus_arn)
result = rule.describe()
return self._create_response(result)
def disable_rule(self) -> Tuple[str, Dict[str, Any]]:
name = self._get_param("Name")
event_bus_arn = self._get_param("EventBusName")
if not name:
return self.error("ValidationException", "Parameter Name is required.")
if not self.events_backend.disable_rule(name):
if not self.events_backend.disable_rule(name, event_bus_arn):
return self.error(
"ResourceNotFoundException", "Rule " + name + " does not exist."
)
@ -91,11 +94,12 @@ class EventsHandler(BaseResponse):
def enable_rule(self) -> Tuple[str, Dict[str, Any]]:
name = self._get_param("Name")
event_bus_arn = self._get_param("EventBusName")
if not name:
return self.error("ValidationException", "Parameter Name is required.")
if not self.events_backend.enable_rule(name):
if not self.events_backend.enable_rule(name, event_bus_arn):
return self.error(
"ResourceNotFoundException", "Rule " + name + " does not exist."
)
@ -107,6 +111,7 @@ class EventsHandler(BaseResponse):
def list_rule_names_by_target(self) -> Tuple[str, Dict[str, Any]]:
target_arn = self._get_param("TargetArn")
event_bus_arn = self._get_param("EventBusName")
next_token = self._get_param("NextToken")
limit = self._get_param("Limit")
@ -114,7 +119,10 @@ class EventsHandler(BaseResponse):
return self.error("ValidationException", "Parameter TargetArn is required.")
rules, token = self.events_backend.list_rule_names_by_target(
target_arn=target_arn, next_token=next_token, limit=limit
target_arn=target_arn,
event_bus_arn=event_bus_arn,
next_token=next_token,
limit=limit,
)
res = {"RuleNames": [rule.name for rule in rules], "NextToken": token}
@ -123,11 +131,15 @@ class EventsHandler(BaseResponse):
def list_rules(self) -> Tuple[str, Dict[str, Any]]:
prefix = self._get_param("NamePrefix")
event_bus_arn = self._get_param("EventBusName")
next_token = self._get_param("NextToken")
limit = self._get_param("Limit")
rules, token = self.events_backend.list_rules(
prefix=prefix, next_token=next_token, limit=limit
prefix=prefix,
event_bus_arn=event_bus_arn,
next_token=next_token,
limit=limit,
)
rules_obj = {
"Rules": [rule.describe() for rule in rules],
@ -138,6 +150,7 @@ class EventsHandler(BaseResponse):
def list_targets_by_rule(self) -> Tuple[str, Dict[str, Any]]:
rule_name = self._get_param("Rule")
event_bus_arn = self._get_param("EventBusName")
next_token = self._get_param("NextToken")
limit = self._get_param("Limit")
@ -146,7 +159,7 @@ class EventsHandler(BaseResponse):
try:
targets = self.events_backend.list_targets_by_rule(
rule_name, next_token, limit
rule_name, event_bus_arn, next_token, limit
)
except KeyError:
return self.error(
@ -170,7 +183,7 @@ class EventsHandler(BaseResponse):
def put_targets(self) -> Tuple[str, Dict[str, Any]]:
rule_name = self._get_param("Rule")
event_bus_name = self._get_param("EventBusName", "default")
event_bus_name = self._get_param("EventBusName")
targets = self._get_param("Targets")
self.events_backend.put_targets(rule_name, event_bus_name, targets)
@ -182,7 +195,7 @@ class EventsHandler(BaseResponse):
def remove_targets(self) -> Tuple[str, Dict[str, Any]]:
rule_name = self._get_param("Rule")
event_bus_name = self._get_param("EventBusName", "default")
event_bus_name = self._get_param("EventBusName")
ids = self._get_param("Ids")
self.events_backend.remove_targets(rule_name, event_bus_name, ids)

View File

@ -265,6 +265,18 @@ events:
- TestAccEventsConnection
- TestAccEventsConnectionDataSource
- TestAccEventsPermission
- TestAccEventsRule
- TestAccEventsTarget_basic
- TestAccEventsTarget_batch
- TestAccEventsTarget_disappears
- TestAccEventsTarget_eventBusName
- TestAccEventsTarget_ecs
- TestAccEventsTarget_eventBusARN
- TestAccEventsTarget_full
- TestAccEventsTarget_generatedTargetID
- TestAccEventsTarget_inputTransformer
- TestAccEventsTarget_kinesis
- TestAccEventsTarget_ssmDocument
firehose:
- TestAccFirehoseDeliveryStreamDataSource_basic
- TestAccFirehoseDeliveryStream_basic

View File

@ -105,6 +105,20 @@ def test_put_rule():
rules[0]["State"].should.equal("ENABLED")
@mock_events
def test_put_rule__where_event_bus_name_is_arn():
client = boto3.client("events", "us-west-2")
event_bus_name = "test-bus"
event_bus_arn = client.create_event_bus(Name=event_bus_name)["EventBusArn"]
rule_arn = client.put_rule(
Name="my-event",
EventPattern='{"source": ["test-source"]}',
EventBusName=event_bus_arn,
)["RuleArn"]
assert rule_arn == f"arn:aws:events:us-west-2:{ACCOUNT_ID}:rule/test-bus/my-event"
@mock_events
def test_put_rule_error_schedule_expression_custom_event_bus():
# given
@ -342,6 +356,40 @@ def test_list_targets_by_rule():
assert len(targets["Targets"]) == len(expected_targets)
@mock_events
def test_list_targets_by_rule_for_different_event_bus():
client = generate_environment()
client.create_event_bus(Name="newEventBus")
client.put_rule(Name="test1", EventBusName="newEventBus", EventPattern="{}")
client.put_targets(
Rule="test1",
EventBusName="newEventBus",
Targets=[
{
"Id": "newtarget",
"Arn": "arn:",
}
],
)
# Total targets with this rule is 7, but, from the docs:
# If you omit [the eventBusName-parameter], the default event bus is used.
targets = client.list_targets_by_rule(Rule="test1")["Targets"]
assert len([t["Id"] for t in targets]) == 6
targets = client.list_targets_by_rule(Rule="test1", EventBusName="default")[
"Targets"
]
assert len([t["Id"] for t in targets]) == 6
targets = client.list_targets_by_rule(Rule="test1", EventBusName="newEventBus")[
"Targets"
]
assert [t["Id"] for t in targets] == ["newtarget"]
@mock_events
def test_remove_targets():
rule_name = get_random_rule()["Name"]