Logs: added support for destinations mocking (#6487)

This commit is contained in:
Macwan Nevil 2023-07-14 18:23:44 +05:30 committed by GitHub
parent e0ceec9e48
commit 1098d4557f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 208 additions and 0 deletions

View File

@ -18,6 +18,34 @@ from .utils import PAGINATION_MODEL, EventMessageFilter
MAX_RESOURCE_POLICIES_PER_REGION = 10
class Destination(BaseModel):
def __init__(
self,
account_id: str,
region: str,
destination_name: str,
role_arn: str,
target_arn: str,
access_policy: Optional[str] = None,
):
self.access_policy = access_policy
self.arn = f"arn:aws:logs:{region}:{account_id}:destination:{destination_name}"
self.creation_time = int(unix_time_millis())
self.destination_name = destination_name
self.role_arn = role_arn
self.target_arn = target_arn
def to_dict(self) -> Dict[str, Any]:
return {
"accessPolicy": self.access_policy,
"arn": self.arn,
"creationTime": self.creation_time,
"destinationName": self.destination_name,
"roleArn": self.role_arn,
"targetArn": self.target_arn,
}
class LogQuery(BaseModel):
def __init__(self, query_id: str, start_time: str, end_time: str, query: str):
self.query_id = query_id
@ -655,6 +683,7 @@ class LogsBackend(BaseBackend):
self.filters = MetricFilters()
self.queries: Dict[str, LogQuery] = dict()
self.resource_policies: Dict[str, LogResourcePolicy] = dict()
self.destinations: Dict[str, Destination] = dict()
@staticmethod
def default_vpc_endpoint_service(
@ -706,6 +735,60 @@ class LogsBackend(BaseBackend):
return groups
def get_destination(self, destination_name: str) -> Destination:
for destination in self.destinations:
if self.destinations[destination].destination_name == destination_name:
return self.destinations[destination]
raise ResourceNotFoundException()
def put_destination(
self, destination_name: str, role_arn: str, target_arn: str
) -> Destination:
for _, destination in self.destinations.items():
if destination.destination_name == destination_name:
if role_arn:
destination.role_arn = role_arn
if target_arn:
destination.target_arn = target_arn
return destination
destination = Destination(
self.account_id, self.region_name, destination_name, role_arn, target_arn
)
self.destinations[destination.arn] = destination
return destination
def delete_destination(self, destination_name: str) -> None:
destination = self.get_destination(destination_name)
self.destinations.pop(destination.arn)
return
def describe_destinations(
self, destination_name_prefix: str, limit: int, next_token: Optional[int] = None
) -> Tuple[List[Dict[str, Any]], Optional[int]]:
if limit > 50:
raise InvalidParameterException(
constraint="Member must have value less than or equal to 50",
parameter="limit",
value=limit,
)
result = []
for destination in self.destinations:
result.append(self.destinations[destination].to_dict())
if next_token:
result = result[: int(next_token)]
result = [
destination
for destination in result
if destination["destinationName"].startswith(destination_name_prefix)
]
return result, next_token
def put_destination_policy(self, destination_name: str, access_policy: str) -> None:
destination = self.get_destination(destination_name)
destination.access_policy = access_policy
return
def create_log_stream(self, log_group_name: str, log_stream_name: str) -> None:
if log_group_name not in self.groups:
raise ResourceNotFoundException()

View File

@ -183,6 +183,41 @@ class LogsResponse(BaseResponse):
result["nextToken"] = next_token
return json.dumps(result)
def put_destination(self) -> str:
destination_name = self._get_param("destinationName")
role_arn = self._get_param("roleArn")
target_arn = self._get_param("targetArn")
destination = self.logs_backend.put_destination(
destination_name, role_arn, target_arn
)
result = {"destination": destination.to_dict()}
return json.dumps(result)
def delete_destination(self) -> str:
destination_name = self._get_param("destinationName")
self.logs_backend.delete_destination(destination_name)
return ""
def describe_destinations(self) -> str:
destination_name_prefix = self._get_param("DestinationNamePrefix")
limit = self._get_param("limit", 50)
next_token = self._get_param("nextToken")
destinations, next_token = self.logs_backend.describe_destinations(
destination_name_prefix, int(limit), next_token
)
result = {"destinations": destinations, "nextToken": next_token}
return json.dumps(result)
def put_destination_policy(self) -> str:
access_policy = self._get_param("accessPolicy")
destination_name = self._get_param("destinationName")
self.logs_backend.put_destination_policy(destination_name, access_policy)
return ""
def create_log_stream(self) -> str:
log_group_name = self._get_param("logGroupName")
log_stream_name = self._get_param("logStreamName")

View File

@ -35,6 +35,20 @@ json_policy_doc = json.dumps(
}
)
access_policy_doc = json.dumps(
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {"AWS": "logs.us-east-1.amazonaws.com"},
"Action": "logs:PutSubscriptionFilter",
"Resource": "destination_arn",
}
],
}
)
@mock_logs
def test_describe_metric_filters_happy_prefix():
@ -255,6 +269,82 @@ def test_delete_metric_filter_invalid_log_group_name(
response["Error"]["Message"].should.contain(failing_constraint)
@mock_logs
def test_destinations():
conn = boto3.client("logs", "us-west-2")
destination_name = "test-destination"
role_arn = "arn:aws:iam::123456789012:role/my-subscription-role"
target_arn = "arn:aws:kinesis:us-east-1:123456789012:stream/my-kinesis-stream"
role_arn_updated = "arn:aws:iam::123456789012:role/my-subscription-role-updated"
target_arn_updated = (
"arn:aws:kinesis:us-east-1:123456789012:stream/my-kinesis-stream-updated"
)
response = conn.describe_destinations(DestinationNamePrefix=destination_name)
assert len(response["destinations"]) == 0
response = conn.put_destination(
destinationName=destination_name,
targetArn=target_arn,
roleArn=role_arn,
tags={"Name": destination_name},
)
assert response["destination"]["destinationName"] == destination_name
assert response["destination"]["targetArn"] == target_arn
assert response["destination"]["roleArn"] == role_arn
response = conn.describe_destinations(DestinationNamePrefix=destination_name)
assert len(response["destinations"]) == 1
assert response["destinations"][0]["destinationName"] == destination_name
assert response["destinations"][0]["targetArn"] == target_arn
assert response["destinations"][0]["roleArn"] == role_arn
response = conn.put_destination(
destinationName=destination_name,
targetArn=target_arn_updated,
roleArn=role_arn_updated,
tags={"Name": destination_name},
)
assert response["destination"]["destinationName"] == destination_name
assert response["destination"]["targetArn"] == target_arn_updated
assert response["destination"]["roleArn"] == role_arn_updated
response = conn.describe_destinations(DestinationNamePrefix=destination_name)
assert len(response["destinations"]) == 1
assert response["destinations"][0]["destinationName"] == destination_name
assert response["destinations"][0]["targetArn"] == target_arn_updated
assert response["destinations"][0]["roleArn"] == role_arn_updated
response = conn.put_destination_policy(
destinationName=destination_name, accessPolicy=access_policy_doc
)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = conn.describe_destinations(DestinationNamePrefix=destination_name)
assert response["destinations"][0]["accessPolicy"] == access_policy_doc
conn.put_destination(
destinationName=f"{destination_name}-1",
targetArn=target_arn,
roleArn=role_arn,
tags={"Name": destination_name},
)
response = conn.describe_destinations(DestinationNamePrefix=destination_name)
assert len(response["destinations"]) == 2
response = conn.describe_destinations(DestinationNamePrefix=f"{destination_name}-1")
assert len(response["destinations"]) == 1
response = conn.delete_destination(destinationName=f"{destination_name}-1")
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = conn.describe_destinations(DestinationNamePrefix=destination_name)
assert len(response["destinations"]) == 1
response = conn.delete_destination(destinationName=destination_name)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
def put_metric_filter(conn, count=1):
count = str(count)
return conn.put_metric_filter(