feat(events): Add policy in put_permission (#4114)

* feat(events): Add policy in put_permission
Also add RemoveAllPermissions in remove_permission
This commit is contained in:
Gonzalo Saad 2021-08-03 11:10:36 -03:00 committed by GitHub
parent 0388b778dd
commit 242de5bc6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 47 deletions

View File

@ -7,6 +7,7 @@ import warnings
from collections import namedtuple from collections import namedtuple
from datetime import datetime from datetime import datetime
from enum import Enum, unique from enum import Enum, unique
from json import JSONDecodeError
from operator import lt, le, eq, ge, gt from operator import lt, le, eq, ge, gt
from boto3 import Session from boto3 import Session
@ -251,7 +252,7 @@ class EventBus(CloudFormationModel):
self.name = name self.name = name
self.tags = tags or [] self.tags = tags or []
self._permissions = {} self._statements = {}
@property @property
def arn(self): def arn(self):
@ -261,25 +262,16 @@ class EventBus(CloudFormationModel):
@property @property
def policy(self): def policy(self):
if not len(self._permissions): if self._statements:
return None policy = {
"Version": "2012-10-17",
"Statement": [stmt.describe() for stmt in self._statements.values()],
}
return json.dumps(policy)
return None
policy = {"Version": "2012-10-17", "Statement": []} def has_permissions(self):
return len(self._statements) > 0
for sid, permission in self._permissions.items():
policy["Statement"].append(
{
"Sid": sid,
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::{}:root".format(permission["Principal"])
},
"Action": permission["Action"],
"Resource": self.arn,
}
)
return json.dumps(policy)
def delete(self, region_name): def delete(self, region_name):
event_backend = events_backends[region_name] event_backend = events_backends[region_name]
@ -335,6 +327,70 @@ class EventBus(CloudFormationModel):
event_bus_name = resource_name event_bus_name = resource_name
event_backend.delete_event_bus(event_bus_name) event_backend.delete_event_bus(event_bus_name)
def _remove_principals_statements(self, *principals):
statements_to_delete = set()
for principal in principals:
for sid, statement in self._statements.items():
if statement.principal == principal:
statements_to_delete.add(sid)
# This is done separately to avoid:
# RuntimeError: dictionary changed size during iteration
for sid in statements_to_delete:
del self._statements[sid]
def add_permission(self, statement_id, action, principal):
self._remove_principals_statements(principal)
statement = EventBusPolicyStatement(
sid=statement_id, action=action, principal=principal, resource=self.arn,
)
self._statements[statement_id] = statement
def add_policy(self, policy):
policy_statements = policy["Statement"]
principals = [stmt["Principal"] for stmt in policy_statements]
self._remove_principals_statements(*principals)
for new_statement in policy_statements:
sid = new_statement["Sid"]
self._statements[sid] = EventBusPolicyStatement.from_dict(new_statement)
def remove_statement(self, sid):
return self._statements.pop(sid, None)
def remove_statements(self):
self._statements.clear()
class EventBusPolicyStatement:
def __init__(self, sid, principal, action, resource, effect="Allow"):
self.sid = sid
self.principal = principal
self.action = action
self.resource = resource
self.effect = effect
def describe(self):
return {
"Sid": self.sid,
"Effect": self.effect,
"Principal": self.principal,
"Action": self.action,
"Resource": self.resource,
}
@classmethod
def from_dict(cls, statement_dict):
return cls(
sid=statement_dict["Sid"],
effect=statement_dict["Effect"],
principal=statement_dict["Principal"],
action=statement_dict["Action"],
resource=statement_dict["Resource"],
)
class Archive(CloudFormationModel): class Archive(CloudFormationModel):
# https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_ListArchives.html#API_ListArchives_RequestParameters # https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_ListArchives.html#API_ListArchives_RequestParameters
@ -1073,49 +1129,65 @@ class EventsBackend(BaseBackend):
def test_event_pattern(self): def test_event_pattern(self):
raise NotImplementedError() raise NotImplementedError()
def put_permission(self, event_bus_name, action, principal, statement_id): @staticmethod
if not event_bus_name: def _put_permission_from_policy(event_bus, policy):
event_bus_name = "default" try:
policy_doc = json.loads(policy)
event_bus = self.describe_event_bus(event_bus_name) event_bus.add_policy(policy_doc)
except JSONDecodeError:
raise JsonRESTError(
"ValidationException", "This policy contains invalid Json"
)
def _put_permission_from_params(self, event_bus, action, principal, statement_id):
if principal is None or self.ACCOUNT_ID.match(principal) is None:
raise JsonRESTError(
"InvalidParameterValue", r"Principal must match ^(\d{1,12}|\*)$"
)
if action is None or action != "events:PutEvents": if action is None or action != "events:PutEvents":
raise JsonRESTError( raise JsonRESTError(
"ValidationException", "ValidationException",
"Provided value in parameter 'action' is not supported.", "Provided value in parameter 'action' is not supported.",
) )
if principal is None or self.ACCOUNT_ID.match(principal) is None:
raise JsonRESTError(
"InvalidParameterValue", r"Principal must match ^(\d{1,12}|\*)$"
)
if statement_id is None or self.STATEMENT_ID.match(statement_id) is None: if statement_id is None or self.STATEMENT_ID.match(statement_id) is None:
raise JsonRESTError( raise JsonRESTError(
"InvalidParameterValue", r"StatementId must match ^[a-zA-Z0-9-_]{1,64}$" "InvalidParameterValue", r"StatementId must match ^[a-zA-Z0-9-_]{1,64}$"
) )
event_bus._permissions[statement_id] = { principal = {"AWS": f"arn:aws:iam::{principal}:root"}
"Action": action, event_bus.add_permission(statement_id, action, principal)
"Principal": principal,
}
def remove_permission(self, event_bus_name, statement_id): def put_permission(self, event_bus_name, action, principal, statement_id, policy):
if not event_bus_name: if not event_bus_name:
event_bus_name = "default" event_bus_name = "default"
event_bus = self.describe_event_bus(event_bus_name) event_bus = self.describe_event_bus(event_bus_name)
if not len(event_bus._permissions): if policy:
raise JsonRESTError( self._put_permission_from_policy(event_bus, policy)
"ResourceNotFoundException", "EventBus does not have a policy." else:
) self._put_permission_from_params(event_bus, action, principal, statement_id)
if not event_bus._permissions.pop(statement_id, None): def remove_permission(self, event_bus_name, statement_id, remove_all_permissions):
raise JsonRESTError( if not event_bus_name:
"ResourceNotFoundException", event_bus_name = "default"
"Statement with the provided id does not exist.",
) event_bus = self.describe_event_bus(event_bus_name)
if remove_all_permissions:
event_bus.remove_statements()
else:
if not event_bus.has_permissions():
raise JsonRESTError(
"ResourceNotFoundException", "EventBus does not have a policy."
)
statement = event_bus.remove_statement(statement_id)
if not statement:
raise JsonRESTError(
"ResourceNotFoundException",
"Statement with the provided id does not exist.",
)
def describe_event_bus(self, name): def describe_event_bus(self, name):
if not name: if not name:
@ -1229,7 +1301,7 @@ class EventsBackend(BaseBackend):
"EventPattern": json.dumps(rule_event_pattern), "EventPattern": json.dumps(rule_event_pattern),
"EventBusName": event_bus.name, "EventBusName": event_bus.name,
"ManagedBy": "prod.vhs.events.aws.internal", "ManagedBy": "prod.vhs.events.aws.internal",
} },
) )
self.put_targets( self.put_targets(
rule.name, rule.name,

View File

@ -251,9 +251,10 @@ class EventsHandler(BaseResponse):
action = self._get_param("Action") action = self._get_param("Action")
principal = self._get_param("Principal") principal = self._get_param("Principal")
statement_id = self._get_param("StatementId") statement_id = self._get_param("StatementId")
policy = self._get_param("Policy")
self.events_backend.put_permission( self.events_backend.put_permission(
event_bus_name, action, principal, statement_id event_bus_name, action, principal, statement_id, policy
) )
return "" return ""
@ -261,8 +262,11 @@ class EventsHandler(BaseResponse):
def remove_permission(self): def remove_permission(self):
event_bus_name = self._get_param("EventBusName") event_bus_name = self._get_param("EventBusName")
statement_id = self._get_param("StatementId") statement_id = self._get_param("StatementId")
remove_all_permissions = self._get_param("RemoveAllPermissions")
self.events_backend.remove_permission(event_bus_name, statement_id) self.events_backend.remove_permission(
event_bus_name, statement_id, remove_all_permissions
)
return "" return ""

View File

@ -7,6 +7,7 @@ TestAccAWSCloudWatchDashboard
TestAccAWSCloudWatchEventApiDestination TestAccAWSCloudWatchEventApiDestination
TestAccAWSCloudWatchEventArchive TestAccAWSCloudWatchEventArchive
TestAccAWSCloudWatchEventBus TestAccAWSCloudWatchEventBus
TestAccAWSCloudwatchEventBusPolicy
TestAccAWSCloudWatchEventConnection TestAccAWSCloudWatchEventConnection
TestAccAWSCloudwatchLogGroupDataSource TestAccAWSCloudwatchLogGroupDataSource
TestAccAWSDataSourceCloudwatch TestAccAWSDataSourceCloudwatch