diff --git a/moto/logs/models.py b/moto/logs/models.py
index 84ebc55e3..3fdefd520 100644
--- a/moto/logs/models.py
+++ b/moto/logs/models.py
@@ -1,5 +1,5 @@
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, Tuple, Optional
+from typing import Any, Dict, Iterable, List, Tuple, Optional
 from moto.core import BaseBackend, BackendDict, BaseModel
 from moto.core import CloudFormationModel
 from moto.core.utils import unix_time_millis
@@ -85,22 +85,20 @@ class LogEvent(BaseModel):
 class LogStream(BaseModel):
     _log_ids = 0
 
-    def __init__(self, account_id: str, region: str, log_group: str, name: str):
-        self.account_id = account_id
-        self.region = region
-        self.arn = f"arn:aws:logs:{region}:{account_id}:log-group:{log_group}:log-stream:{name}"
+    def __init__(self, log_group: "LogGroup", name: str):
+        self.account_id = log_group.account_id
+        self.region = log_group.region
+        self.log_group = log_group
+        self.arn = f"arn:aws:logs:{self.region}:{self.account_id}:log-group:{log_group.name}:log-stream:{name}"
         self.creation_time = int(unix_time_millis())
         self.first_event_timestamp = None
         self.last_event_timestamp = None
         self.last_ingestion_time: Optional[int] = None
         self.log_stream_name = name
         self.stored_bytes = 0
-        self.upload_sequence_token = (
-            0  # I'm guessing this is token needed for sequenceToken by put_events
-        )
+        # I'm guessing this is token needed for sequenceToken by put_events
+        self.upload_sequence_token = 0
         self.events: List[LogEvent] = []
-        self.destination_arn: Optional[str] = None
-        self.filter_name: Optional[str] = None
 
         self.__class__._log_ids += 1
 
@@ -133,12 +131,7 @@ class LogStream(BaseModel):
         res.update(rest)
         return res
 
-    def put_log_events(
-        self,
-        log_group_name: str,
-        log_stream_name: str,
-        log_events: List[Dict[str, Any]],
-    ) -> str:
+    def put_log_events(self, log_events: List[Dict[str, Any]]) -> str:
         # TODO: ensure sequence_token
         # TODO: to be thread safe this would need a lock
         self.last_ingestion_time = int(unix_time_millis())
@@ -152,9 +145,9 @@ class LogStream(BaseModel):
         self.events += events
         self.upload_sequence_token += 1
 
-        service = None
-        if self.destination_arn:
-            service = self.destination_arn.split(":")[2]
+        for subscription_filter in self.log_group.subscription_filters.values():
+
+            service = subscription_filter.destination_arn.split(":")[2]
             formatted_log_events = [
                 {
                     "id": event.event_id,
@@ -163,41 +156,54 @@ class LogStream(BaseModel):
                     "timestamp": event.timestamp,
                 }
                 for event in events
             ]
+            self._send_log_events(
+                service=service,
+                destination_arn=subscription_filter.destination_arn,
+                filter_name=subscription_filter.name,
+                log_events=formatted_log_events,
+            )
+        return f"{self.upload_sequence_token:056d}"
+
+    def _send_log_events(
+        self,
+        service: str,
+        destination_arn: str,
+        filter_name: str,
+        log_events: List[Dict[str, Any]],
+    ) -> None:
 
         if service == "lambda":
             from moto.awslambda import lambda_backends  # due to circular dependency
 
             lambda_backends[self.account_id][self.region].send_log_event(
-                self.destination_arn,
-                self.filter_name,
-                log_group_name,
-                log_stream_name,
-                formatted_log_events,
+                destination_arn,
+                filter_name,
+                self.log_group.name,
+                self.log_stream_name,
+                log_events,
             )
         elif service == "firehose":
             from moto.firehose import firehose_backends
 
             firehose_backends[self.account_id][self.region].send_log_event(
-                self.destination_arn,
-                self.filter_name,
-                log_group_name,
-                log_stream_name,
-                formatted_log_events,
+                destination_arn,
+                filter_name,
+                self.log_group.name,
+                self.log_stream_name,
+                log_events,
             )
         elif service == "kinesis":
             from moto.kinesis import kinesis_backends
 
             kinesis = kinesis_backends[self.account_id][self.region]
             kinesis.send_log_event(
-                self.destination_arn,
-                self.filter_name,
-                log_group_name,
-                log_stream_name,
-                formatted_log_events,
+                destination_arn,
+                filter_name,
+                self.log_group.name,
+                self.log_stream_name,
+                log_events,
             )
 
-        return f"{self.upload_sequence_token:056d}"
-
     def get_log_events(
         self,
         start_time: str,
@@ -295,6 +301,39 @@ class LogStream(BaseModel):
         return events
 
 
+class SubscriptionFilter(BaseModel):
+    def __init__(
+        self,
+        name: str,
+        log_group_name: str,
+        filter_pattern: str,
+        destination_arn: str,
+        role_arn: str,
+    ):
+        self.name = name
+        self.log_group_name = log_group_name
+        self.filter_pattern = filter_pattern
+        self.destination_arn = destination_arn
+        self.role_arn = role_arn
+        self.creation_time = int(unix_time_millis())
+
+    def update(self, filter_pattern: str, destination_arn: str, role_arn: str) -> None:
+        self.filter_pattern = filter_pattern
+        self.destination_arn = destination_arn
+        self.role_arn = role_arn
+
+    def to_json(self) -> Dict[str, Any]:
+        return {
+            "filterName": self.name,
+            "logGroupName": self.log_group_name,
+            "filterPattern": self.filter_pattern,
+            "destinationArn": self.destination_arn,
+            "roleArn": self.role_arn,
+            "distribution": "ByLogStream",
+            "creationTime": self.creation_time,
+        }
+
+
 class LogGroup(CloudFormationModel):
     def __init__(
         self,
@@ -311,10 +350,9 @@ class LogGroup(CloudFormationModel):
         self.creation_time = int(unix_time_millis())
         self.tags = tags
         self.streams: Dict[str, LogStream] = dict()  # {name: LogStream}
-        self.retention_in_days = kwargs.get(
-            "RetentionInDays"
-        )  # AWS defaults to Never Expire for log group retention
-        self.subscription_filters: List[Dict[str, Any]] = []
+        # AWS defaults to Never Expire for log group retention
+        self.retention_in_days = kwargs.get("RetentionInDays")
+        self.subscription_filters: Dict[str, SubscriptionFilter] = {}
 
         # The Amazon Resource Name (ARN) of the CMK to use when encrypting log data. It is optional.
         # Docs:
@@ -365,12 +403,8 @@ class LogGroup(CloudFormationModel):
     def create_log_stream(self, log_stream_name: str) -> None:
         if log_stream_name in self.streams:
             raise ResourceAlreadyExistsException()
-        stream = LogStream(self.account_id, self.region, self.name, log_stream_name)
-        filters = self.describe_subscription_filters()
+        stream = LogStream(log_group=self, name=log_stream_name)
 
-        if filters:
-            stream.destination_arn = filters[0]["destinationArn"]
-            stream.filter_name = filters[0]["filterName"]
         self.streams[log_stream_name] = stream
 
     def delete_log_stream(self, log_stream_name: str) -> None:
@@ -433,14 +467,13 @@ class LogGroup(CloudFormationModel):
 
     def put_log_events(
         self,
-        log_group_name: str,
         log_stream_name: str,
         log_events: List[Dict[str, Any]],
     ) -> str:
         if log_stream_name not in self.streams:
             raise ResourceNotFoundException("The specified log stream does not exist.")
         stream = self.streams[log_stream_name]
-        return stream.put_log_events(log_group_name, log_stream_name, log_events)
+        return stream.put_log_events(log_events)
 
     def get_log_events(
         self,
@@ -556,47 +589,38 @@ class LogGroup(CloudFormationModel):
             k: v for (k, v) in self.tags.items() if k not in tags_to_remove
         }
 
-    def describe_subscription_filters(self) -> List[Dict[str, Any]]:
-        return self.subscription_filters
+    def describe_subscription_filters(self) -> Iterable[SubscriptionFilter]:
+        return self.subscription_filters.values()
 
     def put_subscription_filter(
        self, filter_name: str, filter_pattern: str, destination_arn: str, role_arn: str
     ) -> None:
-        creation_time = int(unix_time_millis())
+        # only two subscription filters can be associated with a log group
+        if len(self.subscription_filters) == 2:
+            raise LimitExceededException()
 
-        # only one subscription filter can be associated with a log group
-        if self.subscription_filters:
-            if self.subscription_filters[0]["filterName"] == filter_name:
-                creation_time = self.subscription_filters[0]["creationTime"]
-            else:
-                raise LimitExceededException()
+        # Update existing filter
+        if filter_name in self.subscription_filters:
+            self.subscription_filters[filter_name].update(
+                filter_pattern, destination_arn, role_arn
+            )
+            return
 
-        for stream in self.streams.values():
-            stream.destination_arn = destination_arn
-            stream.filter_name = filter_name
-
-        self.subscription_filters = [
-            {
-                "filterName": filter_name,
-                "logGroupName": self.name,
-                "filterPattern": filter_pattern,
-                "destinationArn": destination_arn,
-                "roleArn": role_arn,
-                "distribution": "ByLogStream",
-                "creationTime": creation_time,
-            }
-        ]
+        self.subscription_filters[filter_name] = SubscriptionFilter(
+            name=filter_name,
+            log_group_name=self.name,
+            filter_pattern=filter_pattern,
+            destination_arn=destination_arn,
+            role_arn=role_arn,
+        )
 
     def delete_subscription_filter(self, filter_name: str) -> None:
-        if (
-            not self.subscription_filters
-            or self.subscription_filters[0]["filterName"] != filter_name
-        ):
+        if filter_name not in self.subscription_filters:
             raise ResourceNotFoundException(
                 "The specified subscription filter does not exist."
             )
 
-        self.subscription_filters = []
+        self.subscription_filters.pop(filter_name)
 
 
 class LogResourcePolicy(CloudFormationModel):
@@ -881,9 +905,7 @@ class LogsBackend(BaseBackend):
                 allowed_events.append(event)
             last_timestamp = event["timestamp"]
 
-        token = log_group.put_log_events(
-            log_group_name, log_stream_name, allowed_events
-        )
+        token = log_group.put_log_events(log_stream_name, allowed_events)
         return token, rejected_info
 
     def get_log_events(
@@ -1034,7 +1056,7 @@ class LogsBackend(BaseBackend):
 
     def describe_subscription_filters(
         self, log_group_name: str
-    ) -> List[Dict[str, Any]]:
+    ) -> Iterable[SubscriptionFilter]:
         log_group = self.groups.get(log_group_name)
 
         if not log_group:
diff --git a/moto/logs/responses.py b/moto/logs/responses.py
index 836166e83..ba0e3be1f 100644
--- a/moto/logs/responses.py
+++ b/moto/logs/responses.py
@@ -371,11 +371,9 @@ class LogsResponse(BaseResponse):
     def describe_subscription_filters(self) -> str:
         log_group_name = self._get_param("logGroupName")
 
-        subscription_filters = self.logs_backend.describe_subscription_filters(
-            log_group_name
-        )
+        _filters = self.logs_backend.describe_subscription_filters(log_group_name)
 
-        return json.dumps({"subscriptionFilters": subscription_filters})
+        return json.dumps({"subscriptionFilters": [f.to_json() for f in _filters]})
 
     def put_subscription_filter(self) -> str:
         log_group_name = self._get_param("logGroupName")
diff --git a/tests/test_logs/test_integration.py b/tests/test_logs/test_integration.py
index d8e326a37..5a86be97e 100644
--- a/tests/test_logs/test_integration.py
+++ b/tests/test_logs/test_integration.py
@@ -62,7 +62,7 @@ def test_put_subscription_filter_update():
     sub_filter["filterPattern"] = ""
 
     # when
-    # to update an existing subscription filter the 'filerName' must be identical
+    # to update an existing subscription filter the 'filterName' must be identical
     client_logs.put_subscription_filter(
         logGroupName=log_group_name,
         filterName="test",
@@ -82,11 +82,17 @@ def test_put_subscription_filter_update():
     sub_filter["filterPattern"] = "[]"
 
     # when
-    # only one subscription filter can be associated with a log group
+    # only two subscription filters can be associated with a log group
+    client_logs.put_subscription_filter(
+        logGroupName=log_group_name,
+        filterName="test-2",
+        filterPattern="[]",
+        destinationArn=function_arn,
+    )
     with pytest.raises(ClientError) as e:
         client_logs.put_subscription_filter(
             logGroupName=log_group_name,
-            filterName="test-2",
+            filterName="test-3",
             filterPattern="",
             destinationArn=function_arn,
         )
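For reviewers who want to sanity-check the new behaviour without going through boto3, here is a minimal sketch (not part of the patch) that exercises the changed model layer directly. The `LogGroup` constructor arguments shown below, and importing `LimitExceededException` via `moto.logs.models`, are assumptions based on the surrounding code rather than anything this diff defines.

```python
# Hypothetical usage sketch for the reworked subscription filter handling.
# The LogGroup constructor arguments are assumed, not taken from this diff.
from moto.logs.models import LimitExceededException, LogGroup

group = LogGroup(
    account_id="123456789012",
    region="us-east-1",
    name="my-log-group",
    tags={},
)

# Two filters with distinct names are stored, keyed by filter name.
group.put_subscription_filter(
    "test", "", "arn:aws:lambda:us-east-1:123456789012:function:one", "role"
)
group.put_subscription_filter(
    "test-2", "[]", "arn:aws:lambda:us-east-1:123456789012:function:two", "role"
)

# Re-using an existing name updates that filter in place instead of failing.
group.put_subscription_filter(
    "test-2", "", "arn:aws:lambda:us-east-1:123456789012:function:two", "role"
)

# A third distinct name hits the new two-filter limit.
try:
    group.put_subscription_filter(
        "test-3", "", "arn:aws:lambda:us-east-1:123456789012:function:three", "role"
    )
except LimitExceededException:
    print("only two subscription filters per log group")

print([f.to_json()["filterName"] for f in group.describe_subscription_filters()])
```

Keying `subscription_filters` by filter name is what turns the update-versus-create distinction and the two-filter cap into a couple of dictionary operations instead of the previous single-element-list bookkeeping on each stream.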