Techdebt: MyPy L (#6120)

Bert Blommers, 2023-03-24 09:43:51 -01:00, committed by GitHub
parent 3adbb8136a
commit 93eb669af4
6 changed files with 311 additions and 230 deletions

moto/logs/exceptions.py

@@ -1,3 +1,4 @@
from typing import Any, Optional
from moto.core.exceptions import JsonRESTError
@@ -6,7 +7,7 @@ class LogsClientError(JsonRESTError):
class ResourceNotFoundException(LogsClientError):
def __init__(self, msg=None):
def __init__(self, msg: Optional[str] = None):
self.code = 400
super().__init__(
"ResourceNotFoundException", msg or "The specified log group does not exist"
@@ -14,7 +15,13 @@ class ResourceNotFoundException(LogsClientError):
class InvalidParameterException(LogsClientError):
def __init__(self, msg=None, constraint=None, parameter=None, value=None):
def __init__(
self,
msg: Optional[str] = None,
constraint: Optional[str] = None,
parameter: Optional[str] = None,
value: Any = None,
):
self.code = 400
if constraint:
msg = f"1 validation error detected: Value '{value}' at '{parameter}' failed to satisfy constraint: {constraint}"
@@ -24,7 +31,7 @@ class InvalidParameterException(LogsClientError):
class ResourceAlreadyExistsException(LogsClientError):
def __init__(self):
def __init__(self) -> None:
self.code = 400
super().__init__(
"ResourceAlreadyExistsException", "The specified log group already exists"
@@ -32,6 +39,6 @@ class ResourceAlreadyExistsException(LogsClientError):
class LimitExceededException(LogsClientError):
def __init__(self):
def __init__(self) -> None:
self.code = 400
super().__init__("LimitExceededException", "Resource limit exceeded.")

moto/logs/metric_filters.py

@@ -1,22 +1,35 @@
def find_metric_transformation_by_name(metric_transformations, metric_name):
from typing import Any, Dict, List, Optional
def find_metric_transformation_by_name(
metric_transformations: List[Dict[str, Any]], metric_name: str
) -> Optional[Dict[str, Any]]:
for metric in metric_transformations:
if metric["metricName"] == metric_name:
return metric
return None
def find_metric_transformation_by_namespace(metric_transformations, metric_namespace):
def find_metric_transformation_by_namespace(
metric_transformations: List[Dict[str, Any]], metric_namespace: str
) -> Optional[Dict[str, Any]]:
for metric in metric_transformations:
if metric["metricNamespace"] == metric_namespace:
return metric
return None
class MetricFilters:
def __init__(self):
self.metric_filters = []
def __init__(self) -> None:
self.metric_filters: List[Dict[str, Any]] = []
def add_filter(
self, filter_name, filter_pattern, log_group_name, metric_transformations
):
self,
filter_name: str,
filter_pattern: str,
log_group_name: str,
metric_transformations: str,
) -> None:
self.metric_filters.append(
{
"filterName": filter_name,
@@ -27,9 +40,13 @@ class MetricFilters:
)
def get_matching_filters(
self, prefix=None, log_group_name=None, metric_name=None, metric_namespace=None
):
result = []
self,
prefix: Optional[str] = None,
log_group_name: Optional[str] = None,
metric_name: Optional[str] = None,
metric_namespace: Optional[str] = None,
) -> List[Dict[str, Any]]:
result: List[Dict[str, Any]] = []
for f in self.metric_filters:
prefix_matches = prefix is None or f["filterName"].startswith(prefix)
log_group_matches = (
@@ -58,7 +75,9 @@ class MetricFilters:
return result
def delete_filter(self, filter_name=None, log_group_name=None):
def delete_filter(
self, filter_name: Optional[str] = None, log_group_name: Optional[str] = None
) -> List[Dict[str, Any]]:
for f in self.metric_filters:
if f["filterName"] == filter_name and f["logGroupName"] == log_group_name:
self.metric_filters.remove(f)
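
The Optional[Dict[str, Any]] return annotations above force callers to handle the no-match case explicitly. A small sketch of the caller side (illustrative data, not from the diff):

from typing import Any, Dict, List, Optional

def find_by_name(
    transformations: List[Dict[str, Any]], metric_name: str
) -> Optional[Dict[str, Any]]:
    for metric in transformations:
        if metric["metricName"] == metric_name:
            return metric
    return None

found = find_by_name([{"metricName": "ErrorCount"}], "ErrorCount")
if found is not None:  # mypy requires narrowing before subscripting
    print(found["metricName"])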

moto/logs/models.py

@@ -19,7 +19,7 @@ MAX_RESOURCE_POLICIES_PER_REGION = 10
class LogQuery(BaseModel):
def __init__(self, query_id, start_time, end_time, query):
def __init__(self, query_id: str, start_time: str, end_time: str, query: str):
self.query_id = query_id
self.start_time = start_time
self.end_time = end_time
@@ -29,7 +29,7 @@ class LogQuery(BaseModel):
class LogEvent(BaseModel):
_event_id = 0
def __init__(self, ingestion_time, log_event):
def __init__(self, ingestion_time: int, log_event: Dict[str, Any]):
self.ingestion_time = ingestion_time
self.timestamp = log_event["timestamp"]
self.message = log_event["message"]
@@ -37,7 +37,7 @@ class LogEvent(BaseModel):
self.__class__._event_id += 1
""
def to_filter_dict(self):
def to_filter_dict(self) -> Dict[str, Any]:
return {
"eventId": str(self.event_id),
"ingestionTime": self.ingestion_time,
@@ -46,7 +46,7 @@ class LogEvent(BaseModel):
"timestamp": self.timestamp,
}
def to_response_dict(self):
def to_response_dict(self) -> Dict[str, Any]:
return {
"ingestionTime": self.ingestion_time,
"message": self.message,
@@ -57,26 +57,26 @@ class LogEvent(BaseModel):
class LogStream(BaseModel):
_log_ids = 0
def __init__(self, account_id, region, log_group, name):
def __init__(self, account_id: str, region: str, log_group: str, name: str):
self.account_id = account_id
self.region = region
self.arn = f"arn:aws:logs:{region}:{account_id}:log-group:{log_group}:log-stream:{name}"
self.creation_time = int(unix_time_millis())
self.first_event_timestamp = None
self.last_event_timestamp = None
self.last_ingestion_time = None
self.last_ingestion_time: Optional[int] = None
self.log_stream_name = name
self.stored_bytes = 0
self.upload_sequence_token = (
0 # I'm guessing this is token needed for sequenceToken by put_events
)
self.events = []
self.destination_arn = None
self.filter_name = None
self.events: List[LogEvent] = []
self.destination_arn: Optional[str] = None
self.filter_name: Optional[str] = None
self.__class__._log_ids += 1
def _update(self):
def _update(self) -> None:
# events can be empty when stream is described soon after creation
self.first_event_timestamp = (
min([x.timestamp for x in self.events]) if self.events else None
@@ -85,7 +85,7 @@ class LogStream(BaseModel):
max([x.timestamp for x in self.events]) if self.events else None
)
def to_describe_dict(self):
def to_describe_dict(self) -> Dict[str, Any]:
# Compute start and end times
self._update()
@@ -105,7 +105,12 @@ class LogStream(BaseModel):
res.update(rest)
return res
def put_log_events(self, log_group_name, log_stream_name, log_events):
def put_log_events(
self,
log_group_name: str,
log_stream_name: str,
log_events: List[Dict[str, Any]],
) -> str:
# TODO: ensure sequence_token
# TODO: to be thread safe this would need a lock
self.last_ingestion_time = int(unix_time_millis())
@@ -139,7 +144,7 @@ class LogStream(BaseModel):
self.filter_name,
log_group_name,
log_stream_name,
formatted_log_events,
formatted_log_events, # type: ignore
)
elif service == "firehose":
from moto.firehose import firehose_backends
@@ -149,23 +154,23 @@ class LogStream(BaseModel):
self.filter_name,
log_group_name,
log_stream_name,
formatted_log_events,
formatted_log_events, # type: ignore
)
return f"{self.upload_sequence_token:056d}"
def get_log_events(
self,
start_time,
end_time,
limit,
next_token,
start_from_head,
):
start_time: str,
end_time: str,
limit: int,
next_token: Optional[str],
start_from_head: str,
) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[str]]:
if limit is None:
limit = 10000
def filter_func(event):
def filter_func(event: LogEvent) -> bool:
if start_time and event.timestamp < start_time:
return False
@@ -174,7 +179,9 @@ class LogStream(BaseModel):
return True
def get_index_and_direction_from_token(token):
def get_index_and_direction_from_token(
token: Optional[str],
) -> Tuple[Optional[str], int]:
if token is not None:
try:
return token[0], int(token[2:])
@@ -224,8 +231,10 @@ class LogStream(BaseModel):
return (events_page, f"b/{start_index:056d}", f"f/{end_index:056d}")
def filter_log_events(self, start_time, end_time, filter_pattern):
def filter_func(event):
def filter_log_events(
self, start_time: int, end_time: int, filter_pattern: str
) -> List[Dict[str, Any]]:
def filter_func(event: LogEvent) -> bool:
if start_time and event.timestamp < start_time:
return False
@@ -237,7 +246,7 @@ class LogStream(BaseModel):
return True
events = []
events: List[Dict[str, Any]] = []
for event in sorted(
filter(filter_func, self.events), key=lambda x: x.timestamp
):
@@ -248,18 +257,25 @@
class LogGroup(CloudFormationModel):
def __init__(self, account_id, region, name, tags, **kwargs):
def __init__(
self,
account_id: str,
region: str,
name: str,
tags: Optional[Dict[str, str]],
**kwargs: Any,
):
self.name = name
self.account_id = account_id
self.region = region
self.arn = f"arn:aws:logs:{region}:{account_id}:log-group:{name}"
self.creation_time = int(unix_time_millis())
self.tags = tags
self.streams = dict() # {name: LogStream}
self.streams: Dict[str, LogStream] = dict() # {name: LogStream}
self.retention_in_days = kwargs.get(
"RetentionInDays"
) # AWS defaults to Never Expire for log group retention
self.subscription_filters = []
self.subscription_filters: List[Dict[str, Any]] = []
# The Amazon Resource Name (ARN) of the CMK to use when encrypting log data. It is optional.
# Docs:
@@ -267,25 +283,30 @@ class LogGroup(CloudFormationModel):
self.kms_key_id = kwargs.get("kmsKeyId")
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "LogGroupName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-logs-loggroup.html
return "AWS::Logs::LogGroup"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "LogGroup":
properties = cloudformation_json["Properties"]
tags = properties.get("Tags", {})
return logs_backends[account_id][region_name].create_log_group(
resource_name, tags, **properties
)
def create_log_stream(self, log_stream_name):
def create_log_stream(self, log_stream_name: str) -> None:
if log_stream_name in self.streams:
raise ResourceAlreadyExistsException()
stream = LogStream(self.account_id, self.region, self.name, log_stream_name)
@@ -296,20 +317,20 @@ class LogGroup(CloudFormationModel):
stream.filter_name = filters[0]["filterName"]
self.streams[log_stream_name] = stream
def delete_log_stream(self, log_stream_name):
def delete_log_stream(self, log_stream_name: str) -> None:
if log_stream_name not in self.streams:
raise ResourceNotFoundException()
del self.streams[log_stream_name]
def describe_log_streams(
self,
descending,
log_group_name,
log_stream_name_prefix,
order_by,
next_token=None,
limit=None,
):
descending: bool,
log_group_name: str,
log_stream_name_prefix: str,
order_by: str,
limit: int,
next_token: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
# responses only log_stream_name, creation_time, arn, stored_bytes when no events are stored.
log_streams = [
@@ -318,7 +339,7 @@ class LogGroup(CloudFormationModel):
if name.startswith(log_stream_name_prefix)
]
def sorter(item):
def sorter(item: Any) -> Any:
return (
item[0]
if order_by == "LogStreamName"
@@ -354,7 +375,12 @@ class LogGroup(CloudFormationModel):
return log_streams_page, new_token
def put_log_events(self, log_group_name, log_stream_name, log_events):
def put_log_events(
self,
log_group_name: str,
log_stream_name: str,
log_events: List[Dict[str, Any]],
) -> str:
if log_stream_name not in self.streams:
raise ResourceNotFoundException("The specified log stream does not exist.")
stream = self.streams[log_stream_name]
@@ -362,13 +388,13 @@ class LogGroup(CloudFormationModel):
def get_log_events(
self,
log_stream_name,
start_time,
end_time,
limit,
next_token,
start_from_head,
):
log_stream_name: str,
start_time: str,
end_time: str,
limit: int,
next_token: Optional[str],
start_from_head: str,
) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[str]]:
if log_stream_name not in self.streams:
raise ResourceNotFoundException()
stream = self.streams[log_stream_name]
@@ -382,15 +408,15 @@ class LogGroup(CloudFormationModel):
def filter_log_events(
self,
log_group_name,
log_stream_names,
start_time,
end_time,
limit,
next_token,
filter_pattern,
interleaved,
):
log_group_name: str,
log_stream_names: List[str],
start_time: int,
end_time: int,
limit: Optional[int],
next_token: Optional[str],
filter_pattern: str,
interleaved: bool,
) -> Tuple[List[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]:
if not limit:
limit = 10000
streams = [
@@ -409,14 +435,15 @@ class LogGroup(CloudFormationModel):
first_index = 0
if next_token:
try:
group, stream, event_id = next_token.split("@")
group, stream_name, event_id = next_token.split("@")
if group != log_group_name:
raise ValueError()
first_index = (
next(
index
for (index, e) in enumerate(events)
if e["logStreamName"] == stream and e["eventId"] == event_id
if e["logStreamName"] == stream_name
and e["eventId"] == event_id
)
+ 1
)
@@ -440,7 +467,7 @@ class LogGroup(CloudFormationModel):
]
return events_page, next_token, searched_streams
def to_describe_dict(self):
def to_describe_dict(self) -> Dict[str, Any]:
log_group = {
"arn": self.arn,
"creationTime": self.creation_time,
@@ -455,30 +482,30 @@ class LogGroup(CloudFormationModel):
log_group["kmsKeyId"] = self.kms_key_id
return log_group
def set_retention_policy(self, retention_in_days):
def set_retention_policy(self, retention_in_days: Optional[str]) -> None:
self.retention_in_days = retention_in_days
def list_tags(self):
def list_tags(self) -> Dict[str, str]:
return self.tags if self.tags else {}
def tag(self, tags):
def tag(self, tags: Dict[str, str]) -> None:
if self.tags:
self.tags.update(tags)
else:
self.tags = tags
def untag(self, tags_to_remove):
def untag(self, tags_to_remove: List[str]) -> None:
if self.tags:
self.tags = {
k: v for (k, v) in self.tags.items() if k not in tags_to_remove
}
def describe_subscription_filters(self):
def describe_subscription_filters(self) -> List[Dict[str, Any]]:
return self.subscription_filters
def put_subscription_filter(
self, filter_name, filter_pattern, destination_arn, role_arn
):
self, filter_name: str, filter_pattern: str, destination_arn: str, role_arn: str
) -> None:
creation_time = int(unix_time_millis())
# only one subscription filter can be associated with a log group
@@ -504,7 +531,7 @@ class LogGroup(CloudFormationModel):
}
]
def delete_subscription_filter(self, filter_name):
def delete_subscription_filter(self, filter_name: str) -> None:
if (
not self.subscription_filters
or self.subscription_filters[0]["filterName"] != filter_name
@@ -517,16 +544,16 @@
class LogResourcePolicy(CloudFormationModel):
def __init__(self, policy_name, policy_document):
def __init__(self, policy_name: str, policy_document: str):
self.policy_name = policy_name
self.policy_document = policy_document
self.last_updated_time = int(unix_time_millis())
def update(self, policy_document):
def update(self, policy_document: str) -> None:
self.policy_document = policy_document
self.last_updated_time = int(unix_time_millis())
def describe(self):
def describe(self) -> Dict[str, Any]:
return {
"policyName": self.policy_name,
"policyDocument": self.policy_document,
@@ -534,22 +561,27 @@ class LogResourcePolicy(CloudFormationModel):
}
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.policy_name
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "PolicyName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-logs-resourcepolicy.html
return "AWS::Logs::ResourcePolicy"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "LogResourcePolicy":
properties = cloudformation_json["Properties"]
policy_name = properties["PolicyName"]
policy_document = properties["PolicyDocument"]
@@ -558,14 +590,14 @@ class LogResourcePolicy(CloudFormationModel):
)
@classmethod
def update_from_cloudformation_json(
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource,
new_resource_name,
cloudformation_json,
account_id,
region_name,
):
original_resource: Any,
new_resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> "LogResourcePolicy":
properties = cloudformation_json["Properties"]
policy_name = properties["PolicyName"]
policy_document = properties["PolicyDocument"]
@@ -578,30 +610,36 @@ class LogResourcePolicy(CloudFormationModel):
return updated
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name
):
return logs_backends[account_id][region_name].delete_resource_policy(
resource_name
)
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
logs_backends[account_id][region_name].delete_resource_policy(resource_name)
class LogsBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.groups = dict() # { logGroupName: LogGroup}
self.groups: Dict[str, LogGroup] = dict()
self.filters = MetricFilters()
self.queries = dict()
self.resource_policies = dict()
self.queries: Dict[str, LogQuery] = dict()
self.resource_policies: Dict[str, LogResourcePolicy] = dict()
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "logs"
)
def create_log_group(self, log_group_name, tags, **kwargs):
def create_log_group(
self, log_group_name: str, tags: Dict[str, str], **kwargs: Any
) -> LogGroup:
if log_group_name in self.groups:
raise ResourceAlreadyExistsException()
if len(log_group_name) > 512:
@@ -624,30 +662,27 @@ class LogsBackend(BaseBackend):
self.account_id, self.region_name, log_group_name, tags
)
def delete_log_group(self, log_group_name):
def delete_log_group(self, log_group_name: str) -> None:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
del self.groups[log_group_name]
@paginate(pagination_model=PAGINATION_MODEL)
def describe_log_groups(self, log_group_name_prefix=None):
if log_group_name_prefix is None:
log_group_name_prefix = ""
def describe_log_groups(self, log_group_name_prefix: Optional[str] = None) -> List[Dict[str, Any]]: # type: ignore[misc]
groups = [
group.to_describe_dict()
for name, group in self.groups.items()
if name.startswith(log_group_name_prefix)
if name.startswith(log_group_name_prefix or "")
]
groups = sorted(groups, key=lambda x: x["logGroupName"])
return groups
def create_log_stream(self, log_group_name: str, log_stream_name: str) -> LogStream:
def create_log_stream(self, log_group_name: str, log_stream_name: str) -> None:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
log_group = self.groups[log_group_name]
return log_group.create_log_stream(log_stream_name)
log_group.create_log_stream(log_stream_name)
def ensure_log_stream(self, log_group_name: str, log_stream_name: str) -> None:
if log_group_name not in self.groups:
@@ -658,21 +693,21 @@ class LogsBackend(BaseBackend):
self.create_log_stream(log_group_name, log_stream_name)
def delete_log_stream(self, log_group_name, log_stream_name):
def delete_log_stream(self, log_group_name: str, log_stream_name: str) -> None:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
log_group = self.groups[log_group_name]
return log_group.delete_log_stream(log_stream_name)
log_group.delete_log_stream(log_stream_name)
def describe_log_streams(
self,
descending,
limit,
log_group_name,
log_stream_name_prefix,
next_token,
order_by,
):
descending: bool,
limit: int,
log_group_name: str,
log_stream_name_prefix: str,
next_token: Optional[str],
order_by: str,
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
if limit > 50:
@@ -740,14 +775,14 @@ class LogsBackend(BaseBackend):
def get_log_events(
self,
log_group_name,
log_stream_name,
start_time,
end_time,
limit,
next_token,
start_from_head,
):
log_group_name: str,
log_stream_name: str,
start_time: str,
end_time: str,
limit: int,
next_token: Optional[str],
start_from_head: str,
) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[str]]:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
if limit and limit > 1000:
@@ -763,15 +798,15 @@ class LogsBackend(BaseBackend):
def filter_log_events(
self,
log_group_name,
log_stream_names,
start_time,
end_time,
limit,
next_token,
filter_pattern,
interleaved,
):
log_group_name: str,
log_stream_names: List[str],
start_time: int,
end_time: int,
limit: Optional[int],
next_token: Optional[str],
filter_pattern: str,
interleaved: bool,
) -> Tuple[List[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]:
"""
The following filter patterns are currently supported: Single Terms, Multiple Terms, Exact Phrases.
If the pattern is not supported, all events are returned.
@@ -796,32 +831,29 @@ class LogsBackend(BaseBackend):
interleaved,
)
def put_retention_policy(self, log_group_name, retention_in_days):
def put_retention_policy(self, log_group_name: str, retention_in_days: str) -> None:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
log_group = self.groups[log_group_name]
return log_group.set_retention_policy(retention_in_days)
self.groups[log_group_name].set_retention_policy(retention_in_days)
def delete_retention_policy(self, log_group_name):
def delete_retention_policy(self, log_group_name: str) -> None:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
log_group = self.groups[log_group_name]
return log_group.set_retention_policy(None)
self.groups[log_group_name].set_retention_policy(None)
def describe_resource_policies(
self, next_token, limit
): # pylint: disable=unused-argument
def describe_resource_policies(self) -> List[LogResourcePolicy]:
"""Return list of resource policies.
The next_token and limit arguments are ignored. The maximum
number of resource policies per region is a small number (less
than 50), so pagination isn't needed.
"""
limit = limit or MAX_RESOURCE_POLICIES_PER_REGION
return list(self.resource_policies.values())
def put_resource_policy(self, policy_name, policy_doc):
def put_resource_policy(
self, policy_name: str, policy_doc: str
) -> LogResourcePolicy:
"""Creates/updates resource policy and return policy object"""
if policy_name in self.resource_policies:
policy = self.resource_policies[policy_name]
@@ -833,52 +865,63 @@ class LogsBackend(BaseBackend):
self.resource_policies[policy_name] = policy
return policy
def delete_resource_policy(self, policy_name):
def delete_resource_policy(self, policy_name: str) -> None:
"""Remove resource policy with a policy name matching given name."""
if policy_name not in self.resource_policies:
raise ResourceNotFoundException(
msg=f"Policy with name [{policy_name}] does not exist"
)
del self.resource_policies[policy_name]
return ""
def list_tags_log_group(self, log_group_name):
def list_tags_log_group(self, log_group_name: str) -> Dict[str, str]:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
log_group = self.groups[log_group_name]
return log_group.list_tags()
def tag_log_group(self, log_group_name, tags):
def tag_log_group(self, log_group_name: str, tags: Dict[str, str]) -> None:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
log_group = self.groups[log_group_name]
log_group.tag(tags)
def untag_log_group(self, log_group_name, tags):
def untag_log_group(self, log_group_name: str, tags: List[str]) -> None:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
log_group = self.groups[log_group_name]
log_group.untag(tags)
def put_metric_filter(
self, filter_name, filter_pattern, log_group_name, metric_transformations
):
self,
filter_name: str,
filter_pattern: str,
log_group_name: str,
metric_transformations: str,
) -> None:
self.filters.add_filter(
filter_name, filter_pattern, log_group_name, metric_transformations
)
def describe_metric_filters(
self, prefix=None, log_group_name=None, metric_name=None, metric_namespace=None
):
self,
prefix: Optional[str] = None,
log_group_name: Optional[str] = None,
metric_name: Optional[str] = None,
metric_namespace: Optional[str] = None,
) -> List[Dict[str, Any]]:
filters = self.filters.get_matching_filters(
prefix, log_group_name, metric_name, metric_namespace
)
return filters
def delete_metric_filter(self, filter_name=None, log_group_name=None):
def delete_metric_filter(
self, filter_name: Optional[str] = None, log_group_name: Optional[str] = None
) -> None:
self.filters.delete_filter(filter_name, log_group_name)
def describe_subscription_filters(self, log_group_name):
def describe_subscription_filters(
self, log_group_name: str
) -> List[Dict[str, Any]]:
log_group = self.groups.get(log_group_name)
if not log_group:
@@ -887,8 +930,13 @@ class LogsBackend(BaseBackend):
return log_group.describe_subscription_filters()
def put_subscription_filter(
self, log_group_name, filter_name, filter_pattern, destination_arn, role_arn
):
self,
log_group_name: str,
filter_name: str,
filter_pattern: str,
destination_arn: str,
role_arn: str,
) -> None:
log_group = self.groups.get(log_group_name)
if not log_group:
@@ -932,7 +980,7 @@ class LogsBackend(BaseBackend):
filter_name, filter_pattern, destination_arn, role_arn
)
def delete_subscription_filter(self, log_group_name, filter_name):
def delete_subscription_filter(self, log_group_name: str, filter_name: str) -> None:
log_group = self.groups.get(log_group_name)
if not log_group:
@@ -940,22 +988,29 @@ class LogsBackend(BaseBackend):
log_group.delete_subscription_filter(filter_name)
def start_query(self, log_group_names, start_time, end_time, query_string):
def start_query(
self,
log_group_names: List[str],
start_time: str,
end_time: str,
query_string: str,
) -> str:
for log_group_name in log_group_names:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
query_id = mock_random.uuid1()
query_id = str(mock_random.uuid1())
self.queries[query_id] = LogQuery(query_id, start_time, end_time, query_string)
return query_id
def create_export_task(self, log_group_name, destination):
def create_export_task(
self, log_group_name: str, destination: Dict[str, Any]
) -> str:
s3_backends[self.account_id]["global"].get_bucket(destination)
if log_group_name not in self.groups:
raise ResourceNotFoundException()
task_id = mock_random.uuid4()
return task_id
return str(mock_random.uuid4())
logs_backends = BackendDict(LogsBackend, "logs")
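
Many of the changes above annotate empty containers at the point of assignment (self.streams: Dict[str, LogStream] = dict(), self.events: List[LogEvent] = [], and so on). Without such an annotation mypy cannot infer an element type from an empty literal and reports a "need type annotation" (var-annotated) error. A minimal sketch, independent of moto:

from typing import Any, Dict, List

class Backend:
    def __init__(self) -> None:
        # Without these annotations mypy reports:
        # "Need type annotation for 'groups'"  [var-annotated]
        self.groups: Dict[str, Any] = dict()
        self.filters: List[Dict[str, Any]] = []

backend = Backend()
backend.groups["group-1"] = {"retentionInDays": 7}
backend.filters.append({"filterName": "filter-1"})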

moto/logs/responses.py

@@ -1,10 +1,11 @@
import json
import re
from typing import Any, Callable, Optional
from .exceptions import InvalidParameterException
from moto.core.responses import BaseResponse
from .models import logs_backends
from .models import logs_backends, LogsBackend
# See http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/Welcome.html
@@ -13,8 +14,12 @@ REGEX_LOG_GROUP_NAME = r"[-._\/#A-Za-z0-9]+"
def validate_param(
param_name, param_value, constraint, constraint_expression, pattern=None
):
param_name: str,
param_value: str,
constraint: str,
constraint_expression: Callable[[str], bool],
pattern: Optional[str] = None,
) -> None:
try:
assert constraint_expression(param_value)
except (AssertionError, TypeError):
@@ -33,31 +38,25 @@ def validate_param(
class LogsResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="logs")
@property
def logs_backend(self):
def logs_backend(self) -> LogsBackend:
return logs_backends[self.current_account][self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param_name, if_none=None):
return self.request_params.get(param_name, if_none)
def _get_validated_param(
self, param, constraint, constraint_expression, pattern=None
):
self,
param: str,
constraint: str,
constraint_expression: Callable[[str], bool],
pattern: Optional[str] = None,
) -> Any:
param_value = self._get_param(param)
validate_param(param, param_value, constraint, constraint_expression, pattern)
return param_value
def put_metric_filter(self):
def put_metric_filter(self) -> str:
filter_name = self._get_validated_param(
"filterName",
"Minimum length of 1. Maximum length of 512.",
@@ -85,7 +84,7 @@ class LogsResponse(BaseResponse):
return ""
def describe_metric_filters(self):
def describe_metric_filters(self) -> str:
filter_name_prefix = self._get_validated_param(
"filterNamePrefix",
"Minimum length of 1. Maximum length of 512.",
@@ -134,7 +133,7 @@ class LogsResponse(BaseResponse):
)
return json.dumps({"metricFilters": filters, "nextToken": next_token})
def delete_metric_filter(self):
def delete_metric_filter(self) -> str:
filter_name = self._get_validated_param(
"filterName",
"Minimum length of 1. Maximum length of 512.",
@@ -151,7 +150,7 @@ class LogsResponse(BaseResponse):
self.logs_backend.delete_metric_filter(filter_name, log_group_name)
return ""
def create_log_group(self):
def create_log_group(self) -> str:
log_group_name = self._get_param("logGroupName")
tags = self._get_param("tags")
kms_key_id = self._get_param("kmsKeyId")
@@ -159,12 +158,12 @@ class LogsResponse(BaseResponse):
self.logs_backend.create_log_group(log_group_name, tags, kmsKeyId=kms_key_id)
return ""
def delete_log_group(self):
def delete_log_group(self) -> str:
log_group_name = self._get_param("logGroupName")
self.logs_backend.delete_log_group(log_group_name)
return ""
def describe_log_groups(self):
def describe_log_groups(self) -> str:
log_group_name_prefix = self._get_param("logGroupNamePrefix")
next_token = self._get_param("nextToken")
limit = self._get_param("limit", 50)
@@ -184,19 +183,19 @@ class LogsResponse(BaseResponse):
result["nextToken"] = next_token
return json.dumps(result)
def create_log_stream(self):
def create_log_stream(self) -> str:
log_group_name = self._get_param("logGroupName")
log_stream_name = self._get_param("logStreamName")
self.logs_backend.create_log_stream(log_group_name, log_stream_name)
return ""
def delete_log_stream(self):
def delete_log_stream(self) -> str:
log_group_name = self._get_param("logGroupName")
log_stream_name = self._get_param("logStreamName")
self.logs_backend.delete_log_stream(log_group_name, log_stream_name)
return ""
def describe_log_streams(self):
def describe_log_streams(self) -> str:
log_group_name = self._get_param("logGroupName")
log_stream_name_prefix = self._get_param("logStreamNamePrefix", "")
descending = self._get_param("descending", False)
@@ -214,7 +213,7 @@ class LogsResponse(BaseResponse):
)
return json.dumps({"logStreams": streams, "nextToken": next_token})
def put_log_events(self):
def put_log_events(self) -> str:
log_group_name = self._get_param("logGroupName")
log_stream_name = self._get_param("logStreamName")
log_events = self._get_param("logEvents")
@@ -232,7 +231,7 @@ class LogsResponse(BaseResponse):
else:
return json.dumps({"nextSequenceToken": next_sequence_token})
def get_log_events(self):
def get_log_events(self) -> str:
log_group_name = self._get_param("logGroupName")
log_stream_name = self._get_param("logStreamName")
start_time = self._get_param("startTime")
@@ -262,7 +261,7 @@ class LogsResponse(BaseResponse):
}
)
def filter_log_events(self):
def filter_log_events(self) -> str:
log_group_name = self._get_param("logGroupName")
log_stream_names = self._get_param("logStreamNames", [])
start_time = self._get_param("startTime")
@@ -291,52 +290,50 @@ class LogsResponse(BaseResponse):
}
)
def put_retention_policy(self):
def put_retention_policy(self) -> str:
log_group_name = self._get_param("logGroupName")
retention_in_days = self._get_param("retentionInDays")
self.logs_backend.put_retention_policy(log_group_name, retention_in_days)
return ""
def delete_retention_policy(self):
def delete_retention_policy(self) -> str:
log_group_name = self._get_param("logGroupName")
self.logs_backend.delete_retention_policy(log_group_name)
return ""
def describe_resource_policies(self):
next_token = self._get_param("nextToken")
limit = self._get_param("limit")
policies = self.logs_backend.describe_resource_policies(next_token, limit)
def describe_resource_policies(self) -> str:
policies = self.logs_backend.describe_resource_policies()
return json.dumps({"resourcePolicies": [p.describe() for p in policies]})
def put_resource_policy(self):
def put_resource_policy(self) -> str:
policy_name = self._get_param("policyName")
policy_doc = self._get_param("policyDocument")
policy = self.logs_backend.put_resource_policy(policy_name, policy_doc)
return json.dumps({"resourcePolicy": policy.describe()})
def delete_resource_policy(self):
def delete_resource_policy(self) -> str:
policy_name = self._get_param("policyName")
self.logs_backend.delete_resource_policy(policy_name)
return ""
def list_tags_log_group(self):
def list_tags_log_group(self) -> str:
log_group_name = self._get_param("logGroupName")
tags = self.logs_backend.list_tags_log_group(log_group_name)
return json.dumps({"tags": tags})
def tag_log_group(self):
def tag_log_group(self) -> str:
log_group_name = self._get_param("logGroupName")
tags = self._get_param("tags")
self.logs_backend.tag_log_group(log_group_name, tags)
return ""
def untag_log_group(self):
def untag_log_group(self) -> str:
log_group_name = self._get_param("logGroupName")
tags = self._get_param("tags")
self.logs_backend.untag_log_group(log_group_name, tags)
return ""
def describe_subscription_filters(self):
def describe_subscription_filters(self) -> str:
log_group_name = self._get_param("logGroupName")
subscription_filters = self.logs_backend.describe_subscription_filters(
@@ -345,7 +342,7 @@ class LogsResponse(BaseResponse):
return json.dumps({"subscriptionFilters": subscription_filters})
def put_subscription_filter(self):
def put_subscription_filter(self) -> str:
log_group_name = self._get_param("logGroupName")
filter_name = self._get_param("filterName")
filter_pattern = self._get_param("filterPattern")
@@ -358,7 +355,7 @@ class LogsResponse(BaseResponse):
return ""
def delete_subscription_filter(self):
def delete_subscription_filter(self) -> str:
log_group_name = self._get_param("logGroupName")
filter_name = self._get_param("filterName")
@@ -366,7 +363,7 @@ class LogsResponse(BaseResponse):
return ""
def start_query(self):
def start_query(self) -> str:
log_group_name = self._get_param("logGroupName")
log_group_names = self._get_param("logGroupNames")
start_time = self._get_param("startTime")
@@ -385,7 +382,7 @@ class LogsResponse(BaseResponse):
return json.dumps({"queryId": f"{query_id}"})
def create_export_task(self):
def create_export_task(self) -> str:
log_group_name = self._get_param("logGroupName")
destination = self._get_param("destination")
task_id = self.logs_backend.create_export_task(
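
The validate_param helper above types its check as constraint_expression: Callable[[str], bool], so any predicate over the raw string value can be passed in. A standalone sketch of that validation pattern (illustrative names, not the moto implementation):

from typing import Callable

def validate(
    param_name: str,
    param_value: str,
    constraint_expression: Callable[[str], bool],
) -> None:
    if not constraint_expression(param_value):
        raise ValueError(f"Value '{param_value}' at '{param_name}' failed validation")

validate("logGroupName", "my-group", lambda value: 1 <= len(value) <= 512)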

moto/logs/utils.py

@@ -1,3 +1,6 @@
from typing import Type
PAGINATION_MODEL = {
"describe_log_groups": {
"input_token": "next_token",
@@ -16,31 +19,31 @@ PAGINATION_MODEL = {
class FilterPattern:
def __init__(self, term):
def __init__(self, term: str):
self.term = term
class QuotedTermFilterPattern(FilterPattern):
def matches(self, message):
def matches(self, message: str) -> bool:
# We still have the quotes around the term - we should remove those in the parser
return self.term[1:-1] in message
class SingleTermFilterPattern(FilterPattern):
def matches(self, message):
def matches(self, message: str) -> bool:
required_words = self.term.split(" ")
return all([word in message for word in required_words])
class UnsupportedFilterPattern(FilterPattern):
def matches(self, message): # pylint: disable=unused-argument
def matches(self, message: str) -> bool: # pylint: disable=unused-argument
return True
class EventMessageFilter:
def __init__(self, pattern: str):
current_phrase = ""
current_type = None
current_type: Type[FilterPattern] = None # type: ignore
if pattern:
for char in pattern:
if not current_type:
@@ -55,5 +58,5 @@ class EventMessageFilter:
current_type = UnsupportedFilterPattern
self.filter_type = current_type(current_phrase)
def matches(self, message):
return self.filter_type.matches(message)
def matches(self, message: str) -> bool:
return self.filter_type.matches(message) # type: ignore
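
current_type: Type[FilterPattern] above holds a class object rather than an instance; Type[...] tells mypy which classes may be assigned and what calling the variable produces. A minimal sketch reusing the class names from this file:

from typing import Type

class FilterPattern:
    def __init__(self, term: str):
        self.term = term

    def matches(self, message: str) -> bool:
        return self.term in message

class QuotedTermFilterPattern(FilterPattern):
    def matches(self, message: str) -> bool:
        # strip the surrounding quotes before matching
        return self.term[1:-1] in message

chosen: Type[FilterPattern] = QuotedTermFilterPattern  # a class, not an instance
pattern = chosen('"error"')  # mypy infers a FilterPattern instance
print(pattern.matches("an error occurred"))  # True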

setup.cfg

@@ -235,7 +235,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/moto_api,moto/neptune
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/moto_api,moto/neptune
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract
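
With moto/l* added to the files= list, a bare mypy invocation run from the repository root now also checks moto/logs; mypy reads its configuration from setup.cfg automatically. An equivalent explicit invocation, sketched under that working-directory assumption:

import subprocess

# Check everything listed under files= in setup.cfg:
subprocess.run(["mypy"], check=False)

# Or check just the logs service explicitly:
subprocess.run(["mypy", "moto/logs"], check=False)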