Logs: Support two subscription filters (#6724)

This commit is contained in:
Bert Blommers 2023-08-25 08:03:49 +00:00 committed by GitHub
parent 4d4cae08d2
commit a1adf241b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 87 deletions

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Tuple, Optional from typing import Any, Dict, Iterable, List, Tuple, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.core import CloudFormationModel from moto.core import CloudFormationModel
from moto.core.utils import unix_time_millis from moto.core.utils import unix_time_millis
@ -85,22 +85,20 @@ class LogEvent(BaseModel):
class LogStream(BaseModel): class LogStream(BaseModel):
_log_ids = 0 _log_ids = 0
def __init__(self, account_id: str, region: str, log_group: str, name: str): def __init__(self, log_group: "LogGroup", name: str):
self.account_id = account_id self.account_id = log_group.account_id
self.region = region self.region = log_group.region
self.arn = f"arn:aws:logs:{region}:{account_id}:log-group:{log_group}:log-stream:{name}" self.log_group = log_group
self.arn = f"arn:aws:logs:{self.region}:{self.account_id}:log-group:{log_group.name}:log-stream:{name}"
self.creation_time = int(unix_time_millis()) self.creation_time = int(unix_time_millis())
self.first_event_timestamp = None self.first_event_timestamp = None
self.last_event_timestamp = None self.last_event_timestamp = None
self.last_ingestion_time: Optional[int] = None self.last_ingestion_time: Optional[int] = None
self.log_stream_name = name self.log_stream_name = name
self.stored_bytes = 0 self.stored_bytes = 0
self.upload_sequence_token = ( # I'm guessing this is token needed for sequenceToken by put_events
0 # I'm guessing this is token needed for sequenceToken by put_events self.upload_sequence_token = 0
)
self.events: List[LogEvent] = [] self.events: List[LogEvent] = []
self.destination_arn: Optional[str] = None
self.filter_name: Optional[str] = None
self.__class__._log_ids += 1 self.__class__._log_ids += 1
@ -133,12 +131,7 @@ class LogStream(BaseModel):
res.update(rest) res.update(rest)
return res return res
def put_log_events( def put_log_events(self, log_events: List[Dict[str, Any]]) -> str:
self,
log_group_name: str,
log_stream_name: str,
log_events: List[Dict[str, Any]],
) -> str:
# TODO: ensure sequence_token # TODO: ensure sequence_token
# TODO: to be thread safe this would need a lock # TODO: to be thread safe this would need a lock
self.last_ingestion_time = int(unix_time_millis()) self.last_ingestion_time = int(unix_time_millis())
@ -152,9 +145,9 @@ class LogStream(BaseModel):
self.events += events self.events += events
self.upload_sequence_token += 1 self.upload_sequence_token += 1
service = None for subscription_filter in self.log_group.subscription_filters.values():
if self.destination_arn:
service = self.destination_arn.split(":")[2] service = subscription_filter.destination_arn.split(":")[2]
formatted_log_events = [ formatted_log_events = [
{ {
"id": event.event_id, "id": event.event_id,
@ -163,41 +156,54 @@ class LogStream(BaseModel):
} }
for event in events for event in events
] ]
self._send_log_events(
service=service,
destination_arn=subscription_filter.destination_arn,
filter_name=subscription_filter.name,
log_events=formatted_log_events,
)
return f"{self.upload_sequence_token:056d}"
def _send_log_events(
self,
service: str,
destination_arn: str,
filter_name: str,
log_events: List[Dict[str, Any]],
) -> None:
if service == "lambda": if service == "lambda":
from moto.awslambda import lambda_backends # due to circular dependency from moto.awslambda import lambda_backends # due to circular dependency
lambda_backends[self.account_id][self.region].send_log_event( lambda_backends[self.account_id][self.region].send_log_event(
self.destination_arn, destination_arn,
self.filter_name, filter_name,
log_group_name, self.log_group.name,
log_stream_name, self.log_stream_name,
formatted_log_events, log_events,
) )
elif service == "firehose": elif service == "firehose":
from moto.firehose import firehose_backends from moto.firehose import firehose_backends
firehose_backends[self.account_id][self.region].send_log_event( firehose_backends[self.account_id][self.region].send_log_event(
self.destination_arn, destination_arn,
self.filter_name, filter_name,
log_group_name, self.log_group.name,
log_stream_name, self.log_stream_name,
formatted_log_events, log_events,
) )
elif service == "kinesis": elif service == "kinesis":
from moto.kinesis import kinesis_backends from moto.kinesis import kinesis_backends
kinesis = kinesis_backends[self.account_id][self.region] kinesis = kinesis_backends[self.account_id][self.region]
kinesis.send_log_event( kinesis.send_log_event(
self.destination_arn, destination_arn,
self.filter_name, filter_name,
log_group_name, self.log_group.name,
log_stream_name, self.log_stream_name,
formatted_log_events, log_events,
) )
return f"{self.upload_sequence_token:056d}"
def get_log_events( def get_log_events(
self, self,
start_time: str, start_time: str,
@ -295,6 +301,39 @@ class LogStream(BaseModel):
return events return events
class SubscriptionFilter(BaseModel):
def __init__(
self,
name: str,
log_group_name: str,
filter_pattern: str,
destination_arn: str,
role_arn: str,
):
self.name = name
self.log_group_name = log_group_name
self.filter_pattern = filter_pattern
self.destination_arn = destination_arn
self.role_arn = role_arn
self.creation_time = int(unix_time_millis())
def update(self, filter_pattern: str, destination_arn: str, role_arn: str) -> None:
self.filter_pattern = filter_pattern
self.destination_arn = destination_arn
self.role_arn = role_arn
def to_json(self) -> Dict[str, Any]:
return {
"filterName": self.name,
"logGroupName": self.log_group_name,
"filterPattern": self.filter_pattern,
"destinationArn": self.destination_arn,
"roleArn": self.role_arn,
"distribution": "ByLogStream",
"creationTime": self.creation_time,
}
class LogGroup(CloudFormationModel): class LogGroup(CloudFormationModel):
def __init__( def __init__(
self, self,
@ -311,10 +350,9 @@ class LogGroup(CloudFormationModel):
self.creation_time = int(unix_time_millis()) self.creation_time = int(unix_time_millis())
self.tags = tags self.tags = tags
self.streams: Dict[str, LogStream] = dict() # {name: LogStream} self.streams: Dict[str, LogStream] = dict() # {name: LogStream}
self.retention_in_days = kwargs.get( # AWS defaults to Never Expire for log group retention
"RetentionInDays" self.retention_in_days = kwargs.get("RetentionInDays")
) # AWS defaults to Never Expire for log group retention self.subscription_filters: Dict[str, SubscriptionFilter] = {}
self.subscription_filters: List[Dict[str, Any]] = []
# The Amazon Resource Name (ARN) of the CMK to use when encrypting log data. It is optional. # The Amazon Resource Name (ARN) of the CMK to use when encrypting log data. It is optional.
# Docs: # Docs:
@ -365,12 +403,8 @@ class LogGroup(CloudFormationModel):
def create_log_stream(self, log_stream_name: str) -> None: def create_log_stream(self, log_stream_name: str) -> None:
if log_stream_name in self.streams: if log_stream_name in self.streams:
raise ResourceAlreadyExistsException() raise ResourceAlreadyExistsException()
stream = LogStream(self.account_id, self.region, self.name, log_stream_name) stream = LogStream(log_group=self, name=log_stream_name)
filters = self.describe_subscription_filters()
if filters:
stream.destination_arn = filters[0]["destinationArn"]
stream.filter_name = filters[0]["filterName"]
self.streams[log_stream_name] = stream self.streams[log_stream_name] = stream
def delete_log_stream(self, log_stream_name: str) -> None: def delete_log_stream(self, log_stream_name: str) -> None:
@ -433,14 +467,13 @@ class LogGroup(CloudFormationModel):
def put_log_events( def put_log_events(
self, self,
log_group_name: str,
log_stream_name: str, log_stream_name: str,
log_events: List[Dict[str, Any]], log_events: List[Dict[str, Any]],
) -> str: ) -> str:
if log_stream_name not in self.streams: if log_stream_name not in self.streams:
raise ResourceNotFoundException("The specified log stream does not exist.") raise ResourceNotFoundException("The specified log stream does not exist.")
stream = self.streams[log_stream_name] stream = self.streams[log_stream_name]
return stream.put_log_events(log_group_name, log_stream_name, log_events) return stream.put_log_events(log_events)
def get_log_events( def get_log_events(
self, self,
@ -556,47 +589,38 @@ class LogGroup(CloudFormationModel):
k: v for (k, v) in self.tags.items() if k not in tags_to_remove k: v for (k, v) in self.tags.items() if k not in tags_to_remove
} }
def describe_subscription_filters(self) -> List[Dict[str, Any]]: def describe_subscription_filters(self) -> Iterable[SubscriptionFilter]:
return self.subscription_filters return self.subscription_filters.values()
def put_subscription_filter( def put_subscription_filter(
self, filter_name: str, filter_pattern: str, destination_arn: str, role_arn: str self, filter_name: str, filter_pattern: str, destination_arn: str, role_arn: str
) -> None: ) -> None:
creation_time = int(unix_time_millis()) # only two subscription filters can be associated with a log group
if len(self.subscription_filters) == 2:
raise LimitExceededException()
# only one subscription filter can be associated with a log group # Update existing filter
if self.subscription_filters: if filter_name in self.subscription_filters:
if self.subscription_filters[0]["filterName"] == filter_name: self.subscription_filters[filter_name].update(
creation_time = self.subscription_filters[0]["creationTime"] filter_pattern, destination_arn, role_arn
else: )
raise LimitExceededException() return
for stream in self.streams.values(): self.subscription_filters[filter_name] = SubscriptionFilter(
stream.destination_arn = destination_arn name=filter_name,
stream.filter_name = filter_name log_group_name=self.name,
filter_pattern=filter_pattern,
self.subscription_filters = [ destination_arn=destination_arn,
{ role_arn=role_arn,
"filterName": filter_name, )
"logGroupName": self.name,
"filterPattern": filter_pattern,
"destinationArn": destination_arn,
"roleArn": role_arn,
"distribution": "ByLogStream",
"creationTime": creation_time,
}
]
def delete_subscription_filter(self, filter_name: str) -> None: def delete_subscription_filter(self, filter_name: str) -> None:
if ( if filter_name not in self.subscription_filters:
not self.subscription_filters
or self.subscription_filters[0]["filterName"] != filter_name
):
raise ResourceNotFoundException( raise ResourceNotFoundException(
"The specified subscription filter does not exist." "The specified subscription filter does not exist."
) )
self.subscription_filters = [] self.subscription_filters.pop(filter_name)
class LogResourcePolicy(CloudFormationModel): class LogResourcePolicy(CloudFormationModel):
@ -881,9 +905,7 @@ class LogsBackend(BaseBackend):
allowed_events.append(event) allowed_events.append(event)
last_timestamp = event["timestamp"] last_timestamp = event["timestamp"]
token = log_group.put_log_events( token = log_group.put_log_events(log_stream_name, allowed_events)
log_group_name, log_stream_name, allowed_events
)
return token, rejected_info return token, rejected_info
def get_log_events( def get_log_events(
@ -1034,7 +1056,7 @@ class LogsBackend(BaseBackend):
def describe_subscription_filters( def describe_subscription_filters(
self, log_group_name: str self, log_group_name: str
) -> List[Dict[str, Any]]: ) -> Iterable[SubscriptionFilter]:
log_group = self.groups.get(log_group_name) log_group = self.groups.get(log_group_name)
if not log_group: if not log_group:

View File

@ -371,11 +371,9 @@ class LogsResponse(BaseResponse):
def describe_subscription_filters(self) -> str: def describe_subscription_filters(self) -> str:
log_group_name = self._get_param("logGroupName") log_group_name = self._get_param("logGroupName")
subscription_filters = self.logs_backend.describe_subscription_filters( _filters = self.logs_backend.describe_subscription_filters(log_group_name)
log_group_name
)
return json.dumps({"subscriptionFilters": subscription_filters}) return json.dumps({"subscriptionFilters": [f.to_json() for f in _filters]})
def put_subscription_filter(self) -> str: def put_subscription_filter(self) -> str:
log_group_name = self._get_param("logGroupName") log_group_name = self._get_param("logGroupName")

View File

@ -62,7 +62,7 @@ def test_put_subscription_filter_update():
sub_filter["filterPattern"] = "" sub_filter["filterPattern"] = ""
# when # when
# to update an existing subscription filter the 'filerName' must be identical # to update an existing subscription filter the 'filterName' must be identical
client_logs.put_subscription_filter( client_logs.put_subscription_filter(
logGroupName=log_group_name, logGroupName=log_group_name,
filterName="test", filterName="test",
@ -82,11 +82,17 @@ def test_put_subscription_filter_update():
sub_filter["filterPattern"] = "[]" sub_filter["filterPattern"] = "[]"
# when # when
# only one subscription filter can be associated with a log group # only two subscription filters can be associated with a log group
client_logs.put_subscription_filter(
logGroupName=log_group_name,
filterName="test-2",
filterPattern="[]",
destinationArn=function_arn,
)
with pytest.raises(ClientError) as e: with pytest.raises(ClientError) as e:
client_logs.put_subscription_filter( client_logs.put_subscription_filter(
logGroupName=log_group_name, logGroupName=log_group_name,
filterName="test-2", filterName="test-3",
filterPattern="", filterPattern="",
destinationArn=function_arn, destinationArn=function_arn,
) )