Techdebt: MyPy M (#6170)

This commit is contained in:
Bert Blommers 2023-04-03 23:50:19 +01:00 committed by GitHub
parent 52870d6114
commit 706ff9f5e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 752 additions and 746 deletions

View File

@ -4,5 +4,5 @@ from moto.core.exceptions import JsonRESTError
class NotFoundException(JsonRESTError):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("NotFoundException", message)

View File

@ -1,12 +1,15 @@
from collections import OrderedDict
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.mediaconnect.exceptions import NotFoundException
from moto.moto_api._internal import mock_random as random
from moto.utilities.tagging_service import TaggingService
class Flow(BaseModel):
def __init__(self, **kwargs):
def __init__(self, account_id: str, region_name: str, **kwargs: Any):
self.id = random.uuid4().hex
self.availability_zone = kwargs.get("availability_zone")
self.entitlements = kwargs.get("entitlements", [])
self.name = kwargs.get("name")
@ -15,17 +18,19 @@ class Flow(BaseModel):
self.source_failover_config = kwargs.get("source_failover_config", {})
self.sources = kwargs.get("sources", [])
self.vpc_interfaces = kwargs.get("vpc_interfaces", [])
self.status = "STANDBY" # one of 'STANDBY'|'ACTIVE'|'UPDATING'|'DELETING'|'STARTING'|'STOPPING'|'ERROR'
self._previous_status = None
self.description = None
self.flow_arn = None
self.egress_ip = None
self.status: Optional[
str
] = "STANDBY" # one of 'STANDBY'|'ACTIVE'|'UPDATING'|'DELETING'|'STARTING'|'STOPPING'|'ERROR'
self._previous_status: Optional[str] = None
self.description = "A Moto test flow"
self.flow_arn = f"arn:aws:mediaconnect:{region_name}:{account_id}:flow:{self.id}:{self.name}"
self.egress_ip = "127.0.0.1"
if self.source and not self.sources:
self.sources = [
self.source,
]
def to_dict(self, include=None):
def to_dict(self, include: Optional[List[str]] = None) -> Dict[str, Any]:
data = {
"availabilityZone": self.availability_zone,
"description": self.description,
@ -47,7 +52,7 @@ class Flow(BaseModel):
return new_data
return data
def resolve_transient_states(self):
def resolve_transient_states(self) -> None:
if self.status in ["STARTING"]:
self.status = "ACTIVE"
if self.status in ["STOPPING"]:
@ -57,26 +62,18 @@ class Flow(BaseModel):
self._previous_status = None
class Resource(BaseModel):
def __init__(self, **kwargs):
self.resource_arn = kwargs.get("resource_arn")
self.tags = OrderedDict()
def to_dict(self):
data = {
"resourceArn": self.resource_arn,
"tags": self.tags,
}
return data
class MediaConnectBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self._flows = OrderedDict()
self._resources = OrderedDict()
self._flows: Dict[str, Flow] = OrderedDict()
self.tagger = TaggingService()
def _add_source_details(self, source, flow_id, ingest_ip="127.0.0.1"):
def _add_source_details(
self,
source: Optional[Dict[str, Any]],
flow_id: str,
ingest_ip: str = "127.0.0.1",
) -> None:
if source:
source["sourceArn"] = (
f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source"
@ -85,7 +82,9 @@ class MediaConnectBackend(BaseBackend):
if not source.get("entitlementArn"):
source["ingestIp"] = ingest_ip
def _add_entitlement_details(self, entitlement, entitlement_id):
def _add_entitlement_details(
self, entitlement: Optional[Dict[str, Any]], entitlement_id: str
) -> None:
if entitlement:
entitlement["entitlementArn"] = (
f"arn:aws:mediaconnect:{self.region_name}"
@ -93,15 +92,9 @@ class MediaConnectBackend(BaseBackend):
f":{entitlement['name']}"
)
def _create_flow_add_details(self, flow):
flow_id = random.uuid4().hex
flow.description = "A Moto test flow"
flow.egress_ip = "127.0.0.1"
flow.flow_arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:flow:{flow_id}:{flow.name}"
def _create_flow_add_details(self, flow: Flow) -> None:
for index, _source in enumerate(flow.sources):
self._add_source_details(_source, flow_id, f"127.0.0.{index}")
self._add_source_details(_source, flow.id, f"127.0.0.{index}")
for index, output in enumerate(flow.outputs or []):
if output.get("protocol") in ["srt-listener", "zixi-pull"]:
@ -119,16 +112,18 @@ class MediaConnectBackend(BaseBackend):
def create_flow(
self,
availability_zone,
entitlements,
name,
outputs,
source,
source_failover_config,
sources,
vpc_interfaces,
):
availability_zone: str,
entitlements: List[Dict[str, Any]],
name: str,
outputs: List[Dict[str, Any]],
source: Dict[str, Any],
source_failover_config: Dict[str, Any],
sources: List[Dict[str, Any]],
vpc_interfaces: List[Dict[str, Any]],
) -> Flow:
flow = Flow(
account_id=self.account_id,
region_name=self.region_name,
availability_zone=availability_zone,
entitlements=entitlements,
name=name,
@ -142,11 +137,14 @@ class MediaConnectBackend(BaseBackend):
self._flows[flow.flow_arn] = flow
return flow
def list_flows(self, max_results, next_token):
def list_flows(self, max_results: Optional[int]) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
flows = list(self._flows.values())
if max_results is not None:
flows = flows[:max_results]
response_flows = [
return [
fl.to_dict(
include=[
"availabilityZone",
@ -159,74 +157,59 @@ class MediaConnectBackend(BaseBackend):
)
for fl in flows
]
return response_flows, next_token
def describe_flow(self, flow_arn=None):
messages = {}
def describe_flow(self, flow_arn: str) -> Flow:
if flow_arn in self._flows:
flow = self._flows[flow_arn]
flow.resolve_transient_states()
else:
raise NotFoundException(message="Flow not found.")
return flow.to_dict(), messages
return flow
raise NotFoundException(message="Flow not found.")
def delete_flow(self, flow_arn):
def delete_flow(self, flow_arn: str) -> Flow:
if flow_arn in self._flows:
flow = self._flows[flow_arn]
del self._flows[flow_arn]
else:
raise NotFoundException(message="Flow not found.")
return flow_arn, flow.status
return self._flows.pop(flow_arn)
raise NotFoundException(message="Flow not found.")
def start_flow(self, flow_arn):
def start_flow(self, flow_arn: str) -> Flow:
if flow_arn in self._flows:
flow = self._flows[flow_arn]
flow.status = "STARTING"
else:
raise NotFoundException(message="Flow not found.")
return flow_arn, flow.status
return flow
raise NotFoundException(message="Flow not found.")
def stop_flow(self, flow_arn):
def stop_flow(self, flow_arn: str) -> Flow:
if flow_arn in self._flows:
flow = self._flows[flow_arn]
flow.status = "STOPPING"
else:
raise NotFoundException(message="Flow not found.")
return flow_arn, flow.status
return flow
raise NotFoundException(message="Flow not found.")
def tag_resource(self, resource_arn, tags):
if resource_arn in self._resources:
resource = self._resources[resource_arn]
else:
resource = Resource(resource_arn=resource_arn)
resource.tags.update(tags)
self._resources[resource_arn] = resource
return None
def tag_resource(self, resource_arn: str, tags: Dict[str, Any]) -> None:
tag_list = TaggingService.convert_dict_to_tags_input(tags)
self.tagger.tag_resource(resource_arn, tag_list)
def list_tags_for_resource(self, resource_arn):
if resource_arn in self._resources:
resource = self._resources[resource_arn]
else:
raise NotFoundException(message="Resource not found.")
return resource.tags
def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
if self.tagger.has_tags(resource_arn):
return self.tagger.get_tag_dict_for_resource(resource_arn)
raise NotFoundException(message="Resource not found.")
def add_flow_vpc_interfaces(self, flow_arn, vpc_interfaces):
def add_flow_vpc_interfaces(
self, flow_arn: str, vpc_interfaces: List[Dict[str, Any]]
) -> Flow:
if flow_arn in self._flows:
flow = self._flows[flow_arn]
flow.vpc_interfaces = vpc_interfaces
else:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
return flow_arn, flow.vpc_interfaces
return flow
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
def add_flow_outputs(self, flow_arn, outputs):
def add_flow_outputs(self, flow_arn: str, outputs: List[Dict[str, Any]]) -> Flow:
if flow_arn in self._flows:
flow = self._flows[flow_arn]
flow.outputs = outputs
else:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
return flow_arn, flow.outputs
return flow
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
def remove_flow_vpc_interface(self, flow_arn, vpc_interface_name):
def remove_flow_vpc_interface(self, flow_arn: str, vpc_interface_name: str) -> None:
if flow_arn in self._flows:
flow = self._flows[flow_arn]
flow.vpc_interfaces = [
@ -236,9 +219,8 @@ class MediaConnectBackend(BaseBackend):
]
else:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
return flow_arn, vpc_interface_name
def remove_flow_output(self, flow_arn, output_name):
def remove_flow_output(self, flow_arn: str, output_name: str) -> None:
if flow_arn in self._flows:
flow = self._flows[flow_arn]
flow.outputs = [
@ -248,28 +230,27 @@ class MediaConnectBackend(BaseBackend):
]
else:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
return flow_arn, output_name
def update_flow_output(
self,
flow_arn,
output_arn,
cidr_allow_list,
description,
destination,
encryption,
max_latency,
media_stream_output_configuration,
min_latency,
port,
protocol,
remote_id,
sender_control_port,
sender_ip_address,
smoothing_latency,
stream_id,
vpc_interface_attachment,
):
flow_arn: str,
output_arn: str,
cidr_allow_list: List[str],
description: str,
destination: str,
encryption: Dict[str, str],
max_latency: int,
media_stream_output_configuration: List[Dict[str, Any]],
min_latency: int,
port: int,
protocol: str,
remote_id: str,
sender_control_port: int,
sender_ip_address: str,
smoothing_latency: int,
stream_id: str,
vpc_interface_attachment: Dict[str, str],
) -> Dict[str, Any]:
if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn]
@ -292,10 +273,12 @@ class MediaConnectBackend(BaseBackend):
output["smoothingLatency"] = smoothing_latency
output["streamId"] = stream_id
output["vpcInterfaceAttachment"] = vpc_interface_attachment
return flow_arn, output
return output
raise NotFoundException(message=f"output with arn={output_arn} not found")
def add_flow_sources(self, flow_arn, sources):
def add_flow_sources(
self, flow_arn: str, sources: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn]
@ -305,32 +288,32 @@ class MediaConnectBackend(BaseBackend):
arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source:{source_id}:{name}"
source["sourceArn"] = arn
flow.sources = sources
return flow_arn, sources
return sources
def update_flow_source(
self,
flow_arn,
source_arn,
decryption,
description,
entitlement_arn,
ingest_port,
max_bitrate,
max_latency,
max_sync_buffer,
media_stream_source_configurations,
min_latency,
protocol,
sender_control_port,
sender_ip_address,
stream_id,
vpc_interface_name,
whitelist_cidr,
):
flow_arn: str,
source_arn: str,
decryption: str,
description: str,
entitlement_arn: str,
ingest_port: int,
max_bitrate: int,
max_latency: int,
max_sync_buffer: int,
media_stream_source_configurations: List[Dict[str, Any]],
min_latency: int,
protocol: str,
sender_control_port: int,
sender_ip_address: str,
stream_id: str,
vpc_interface_name: str,
whitelist_cidr: str,
) -> Optional[Dict[str, Any]]:
if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn]
source = next(
source: Optional[Dict[str, Any]] = next(
iter(
[source for source in flow.sources if source["sourceArn"] == source_arn]
),
@ -354,13 +337,13 @@ class MediaConnectBackend(BaseBackend):
source["streamId"] = stream_id
source["vpcInterfaceName"] = vpc_interface_name
source["whitelistCidr"] = whitelist_cidr
return flow_arn, source
return source
def grant_flow_entitlements(
self,
flow_arn,
entitlements,
):
flow_arn: str,
entitlements: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn]
@ -371,30 +354,30 @@ class MediaConnectBackend(BaseBackend):
entitlement["entitlementArn"] = arn
flow.entitlements += entitlements
return flow_arn, entitlements
return entitlements
def revoke_flow_entitlement(self, flow_arn, entitlement_arn):
def revoke_flow_entitlement(self, flow_arn: str, entitlement_arn: str) -> None:
if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn]
for entitlement in flow.entitlements:
if entitlement_arn == entitlement["entitlementArn"]:
flow.entitlements.remove(entitlement)
return flow_arn, entitlement_arn
return
raise NotFoundException(
message=f"entitlement with arn={entitlement_arn} not found"
)
def update_flow_entitlement(
self,
flow_arn,
entitlement_arn,
description,
encryption,
entitlement_status,
name,
subscribers,
):
flow_arn: str,
entitlement_arn: str,
description: str,
encryption: Dict[str, str],
entitlement_status: str,
name: str,
subscribers: List[str],
) -> Dict[str, Any]:
if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn]
@ -405,12 +388,10 @@ class MediaConnectBackend(BaseBackend):
entitlement["entitlementStatus"] = entitlement_status
entitlement["name"] = name
entitlement["subscribers"] = subscribers
return flow_arn, entitlement
return entitlement
raise NotFoundException(
message=f"entitlement with arn={entitlement_arn} not found"
)
# add methods from here
mediaconnect_backends = BackendDict(MediaConnectBackend, "mediaconnect")

View File

@ -1,20 +1,20 @@
import json
from moto.core.responses import BaseResponse
from .models import mediaconnect_backends
from .models import mediaconnect_backends, MediaConnectBackend
from urllib.parse import unquote
class MediaConnectResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="mediaconnect")
@property
def mediaconnect_backend(self):
def mediaconnect_backend(self) -> MediaConnectBackend:
return mediaconnect_backends[self.current_account][self.region]
def create_flow(self):
def create_flow(self) -> str:
availability_zone = self._get_param("availabilityZone")
entitlements = self._get_param("entitlements")
name = self._get_param("name")
@ -35,85 +35,79 @@ class MediaConnectResponse(BaseResponse):
)
return json.dumps(dict(flow=flow.to_dict()))
def list_flows(self):
def list_flows(self) -> str:
max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken")
flows, next_token = self.mediaconnect_backend.list_flows(
max_results=max_results, next_token=next_token
)
return json.dumps(dict(flows=flows, nextToken=next_token))
flows = self.mediaconnect_backend.list_flows(max_results=max_results)
return json.dumps(dict(flows=flows))
def describe_flow(self):
def describe_flow(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
flow, messages = self.mediaconnect_backend.describe_flow(flow_arn=flow_arn)
return json.dumps(dict(flow=flow, messages=messages))
flow = self.mediaconnect_backend.describe_flow(flow_arn=flow_arn)
return json.dumps(dict(flow=flow.to_dict()))
def delete_flow(self):
def delete_flow(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
flow_arn, status = self.mediaconnect_backend.delete_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow_arn, status=status))
flow = self.mediaconnect_backend.delete_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status))
def start_flow(self):
def start_flow(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
flow_arn, status = self.mediaconnect_backend.start_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow_arn, status=status))
flow = self.mediaconnect_backend.start_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status))
def stop_flow(self):
def stop_flow(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
flow_arn, status = self.mediaconnect_backend.stop_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow_arn, status=status))
flow = self.mediaconnect_backend.stop_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status))
def tag_resource(self):
def tag_resource(self) -> str:
resource_arn = unquote(self._get_param("resourceArn"))
tags = self._get_param("tags")
self.mediaconnect_backend.tag_resource(resource_arn=resource_arn, tags=tags)
return json.dumps(dict())
def list_tags_for_resource(self):
def list_tags_for_resource(self) -> str:
resource_arn = unquote(self._get_param("resourceArn"))
tags = self.mediaconnect_backend.list_tags_for_resource(
resource_arn=resource_arn
)
tags = self.mediaconnect_backend.list_tags_for_resource(resource_arn)
return json.dumps(dict(tags=tags))
def add_flow_vpc_interfaces(self):
def add_flow_vpc_interfaces(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
vpc_interfaces = self._get_param("vpcInterfaces")
flow_arn, vpc_interfaces = self.mediaconnect_backend.add_flow_vpc_interfaces(
flow = self.mediaconnect_backend.add_flow_vpc_interfaces(
flow_arn=flow_arn, vpc_interfaces=vpc_interfaces
)
return json.dumps(dict(flow_arn=flow_arn, vpc_interfaces=vpc_interfaces))
return json.dumps(
dict(flow_arn=flow.flow_arn, vpc_interfaces=flow.vpc_interfaces)
)
def remove_flow_vpc_interface(self):
def remove_flow_vpc_interface(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
vpc_interface_name = unquote(self._get_param("vpcInterfaceName"))
(
flow_arn,
vpc_interface_name,
) = self.mediaconnect_backend.remove_flow_vpc_interface(
self.mediaconnect_backend.remove_flow_vpc_interface(
flow_arn=flow_arn, vpc_interface_name=vpc_interface_name
)
return json.dumps(
dict(flow_arn=flow_arn, vpc_interface_name=vpc_interface_name)
)
def add_flow_outputs(self):
def add_flow_outputs(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
outputs = self._get_param("outputs")
flow_arn, outputs = self.mediaconnect_backend.add_flow_outputs(
flow = self.mediaconnect_backend.add_flow_outputs(
flow_arn=flow_arn, outputs=outputs
)
return json.dumps(dict(flow_arn=flow_arn, outputs=outputs))
return json.dumps(dict(flow_arn=flow.flow_arn, outputs=flow.outputs))
def remove_flow_output(self):
def remove_flow_output(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
output_name = unquote(self._get_param("outputArn"))
flow_arn, output_name = self.mediaconnect_backend.remove_flow_output(
self.mediaconnect_backend.remove_flow_output(
flow_arn=flow_arn, output_name=output_name
)
return json.dumps(dict(flow_arn=flow_arn, output_name=output_name))
def update_flow_output(self):
def update_flow_output(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
output_arn = unquote(self._get_param("outputArn"))
cidr_allow_list = self._get_param("cidrAllowList")
@ -133,7 +127,7 @@ class MediaConnectResponse(BaseResponse):
smoothing_latency = self._get_param("smoothingLatency")
stream_id = self._get_param("streamId")
vpc_interface_attachment = self._get_param("vpcInterfaceAttachment")
flow_arn, output = self.mediaconnect_backend.update_flow_output(
output = self.mediaconnect_backend.update_flow_output(
flow_arn=flow_arn,
output_arn=output_arn,
cidr_allow_list=cidr_allow_list,
@ -154,15 +148,15 @@ class MediaConnectResponse(BaseResponse):
)
return json.dumps(dict(flowArn=flow_arn, output=output))
def add_flow_sources(self):
def add_flow_sources(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
sources = self._get_param("sources")
flow_arn, sources = self.mediaconnect_backend.add_flow_sources(
sources = self.mediaconnect_backend.add_flow_sources(
flow_arn=flow_arn, sources=sources
)
return json.dumps(dict(flow_arn=flow_arn, sources=sources))
def update_flow_source(self):
def update_flow_source(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
source_arn = unquote(self._get_param("sourceArn"))
description = self._get_param("description")
@ -182,7 +176,7 @@ class MediaConnectResponse(BaseResponse):
stream_id = self._get_param("streamId")
vpc_interface_name = self._get_param("vpcInterfaceName")
whitelist_cidr = self._get_param("whitelistCidr")
flow_arn, source = self.mediaconnect_backend.update_flow_source(
source = self.mediaconnect_backend.update_flow_source(
flow_arn=flow_arn,
source_arn=source_arn,
decryption=decryption,
@ -203,23 +197,23 @@ class MediaConnectResponse(BaseResponse):
)
return json.dumps(dict(flow_arn=flow_arn, source=source))
def grant_flow_entitlements(self):
def grant_flow_entitlements(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
entitlements = self._get_param("entitlements")
flow_arn, entitlements = self.mediaconnect_backend.grant_flow_entitlements(
entitlements = self.mediaconnect_backend.grant_flow_entitlements(
flow_arn=flow_arn, entitlements=entitlements
)
return json.dumps(dict(flow_arn=flow_arn, entitlements=entitlements))
def revoke_flow_entitlement(self):
def revoke_flow_entitlement(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
entitlement_arn = unquote(self._get_param("entitlementArn"))
flow_arn, entitlement_arn = self.mediaconnect_backend.revoke_flow_entitlement(
self.mediaconnect_backend.revoke_flow_entitlement(
flow_arn=flow_arn, entitlement_arn=entitlement_arn
)
return json.dumps(dict(flowArn=flow_arn, entitlementArn=entitlement_arn))
def update_flow_entitlement(self):
def update_flow_entitlement(self) -> str:
flow_arn = unquote(self._get_param("flowArn"))
entitlement_arn = unquote(self._get_param("entitlementArn"))
description = self._get_param("description")
@ -227,7 +221,7 @@ class MediaConnectResponse(BaseResponse):
entitlement_status = self._get_param("entitlementStatus")
name = self._get_param("name")
subscribers = self._get_param("subscribers")
flow_arn, entitlement = self.mediaconnect_backend.update_flow_entitlement(
entitlement = self.mediaconnect_backend.update_flow_entitlement(
flow_arn=flow_arn,
entitlement_arn=entitlement_arn,
description=description,

View File

@ -1,11 +1,12 @@
from collections import OrderedDict
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random
class Input(BaseModel):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn")
self.attached_channels = kwargs.get("attached_channels", [])
self.destinations = kwargs.get("destinations", [])
@ -23,8 +24,8 @@ class Input(BaseModel):
self.tags = kwargs.get("tags")
self.input_type = kwargs.get("input_type")
def to_dict(self):
data = {
def to_dict(self) -> Dict[str, Any]:
return {
"arn": self.arn,
"attachedChannels": self.attached_channels,
"destinations": self.destinations,
@ -41,9 +42,8 @@ class Input(BaseModel):
"tags": self.tags,
"type": self.input_type,
}
return data
def _resolve_transient_states(self):
def _resolve_transient_states(self) -> None:
# Resolve transient states before second call
# (to simulate AWS taking its sweet time with these things)
if self.state in ["CREATING"]:
@ -53,7 +53,7 @@ class Input(BaseModel):
class Channel(BaseModel):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn")
self.cdi_input_specification = kwargs.get("cdi_input_specification")
self.channel_class = kwargs.get("channel_class", "STANDARD")
@ -71,7 +71,7 @@ class Channel(BaseModel):
self.tags = kwargs.get("tags")
self._previous_state = None
def to_dict(self, exclude=None):
def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]:
data = {
"arn": self.arn,
"cdiInputSpecification": self.cdi_input_specification,
@ -97,7 +97,7 @@ class Channel(BaseModel):
del data[key]
return data
def _resolve_transient_states(self):
def _resolve_transient_states(self) -> None:
# Resolve transient states before second call
# (to simulate AWS taking its sweet time with these things)
if self.state in ["CREATING", "STOPPING"]:
@ -112,24 +112,24 @@ class Channel(BaseModel):
class MediaLiveBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self._channels = OrderedDict()
self._inputs = OrderedDict()
self._channels: Dict[str, Channel] = OrderedDict()
self._inputs: Dict[str, Input] = OrderedDict()
def create_channel(
self,
cdi_input_specification,
channel_class,
destinations,
encoder_settings,
input_attachments,
input_specification,
log_level,
name,
role_arn,
tags,
):
cdi_input_specification: Dict[str, Any],
channel_class: str,
destinations: List[Dict[str, Any]],
encoder_settings: Dict[str, Any],
input_attachments: List[Dict[str, Any]],
input_specification: Dict[str, str],
log_level: str,
name: str,
role_arn: str,
tags: Dict[str, str],
) -> Channel:
"""
The RequestID and Reserved parameters are not yet implemented
"""
@ -155,47 +155,49 @@ class MediaLiveBackend(BaseBackend):
self._channels[channel_id] = channel
return channel
def list_channels(self, max_results, next_token):
def list_channels(self, max_results: Optional[int]) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
channels = list(self._channels.values())
if max_results is not None:
channels = channels[:max_results]
response_channels = [
return [
c.to_dict(exclude=["encoderSettings", "pipelineDetails"]) for c in channels
]
return response_channels, next_token
def describe_channel(self, channel_id):
def describe_channel(self, channel_id: str) -> Channel:
channel = self._channels[channel_id]
channel._resolve_transient_states()
return channel.to_dict()
return channel
def delete_channel(self, channel_id):
def delete_channel(self, channel_id: str) -> Channel:
channel = self._channels[channel_id]
channel.state = "DELETING"
return channel.to_dict()
return channel
def start_channel(self, channel_id):
def start_channel(self, channel_id: str) -> Channel:
channel = self._channels[channel_id]
channel.state = "STARTING"
return channel.to_dict()
return channel
def stop_channel(self, channel_id):
def stop_channel(self, channel_id: str) -> Channel:
channel = self._channels[channel_id]
channel.state = "STOPPING"
return channel.to_dict()
return channel
def update_channel(
self,
channel_id,
cdi_input_specification,
destinations,
encoder_settings,
input_attachments,
input_specification,
log_level,
name,
role_arn,
):
channel_id: str,
cdi_input_specification: Dict[str, str],
destinations: List[Dict[str, Any]],
encoder_settings: Dict[str, Any],
input_attachments: List[Dict[str, Any]],
input_specification: Dict[str, str],
log_level: str,
name: str,
role_arn: str,
) -> Channel:
channel = self._channels[channel_id]
channel.cdi_input_specification = cdi_input_specification
channel.destinations = destinations
@ -214,16 +216,16 @@ class MediaLiveBackend(BaseBackend):
def create_input(
self,
destinations,
input_devices,
input_security_groups,
media_connect_flows,
name,
role_arn,
sources,
tags,
input_type,
):
destinations: List[Dict[str, str]],
input_devices: List[Dict[str, str]],
input_security_groups: List[str],
media_connect_flows: List[Dict[str, str]],
name: str,
role_arn: str,
sources: List[Dict[str, str]],
tags: Dict[str, str],
input_type: str,
) -> Input:
"""
The VPC and RequestId parameters are not yet implemented
"""
@ -246,34 +248,35 @@ class MediaLiveBackend(BaseBackend):
self._inputs[input_id] = a_input
return a_input
def describe_input(self, input_id):
def describe_input(self, input_id: str) -> Input:
a_input = self._inputs[input_id]
a_input._resolve_transient_states()
return a_input.to_dict()
return a_input
def list_inputs(self, max_results, next_token):
def list_inputs(self, max_results: Optional[int]) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
inputs = list(self._inputs.values())
if max_results is not None:
inputs = inputs[:max_results]
response_inputs = [i.to_dict() for i in inputs]
return response_inputs, next_token
return [i.to_dict() for i in inputs]
def delete_input(self, input_id):
def delete_input(self, input_id: str) -> None:
a_input = self._inputs[input_id]
a_input.state = "DELETING"
return a_input.to_dict()
def update_input(
self,
destinations,
input_devices,
input_id,
input_security_groups,
media_connect_flows,
name,
role_arn,
sources,
):
destinations: List[Dict[str, str]],
input_devices: List[Dict[str, str]],
input_id: str,
input_security_groups: List[str],
media_connect_flows: List[Dict[str, str]],
name: str,
role_arn: str,
sources: List[Dict[str, str]],
) -> Input:
a_input = self._inputs[input_id]
a_input.destinations = destinations
a_input.input_devices = input_devices

View File

@ -1,17 +1,17 @@
from moto.core.responses import BaseResponse
from .models import medialive_backends
from .models import medialive_backends, MediaLiveBackend
import json
class MediaLiveResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="medialive")
@property
def medialive_backend(self):
def medialive_backend(self) -> MediaLiveBackend:
return medialive_backends[self.current_account][self.region]
def create_channel(self):
def create_channel(self) -> str:
cdi_input_specification = self._get_param("cdiInputSpecification")
channel_class = self._get_param("channelClass")
destinations = self._get_param("destinations")
@ -39,34 +39,33 @@ class MediaLiveResponse(BaseResponse):
dict(channel=channel.to_dict(exclude=["pipelinesRunningCount"]))
)
def list_channels(self):
def list_channels(self) -> str:
max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken")
channels, next_token = self.medialive_backend.list_channels(
max_results=max_results, next_token=next_token
)
channels = self.medialive_backend.list_channels(max_results=max_results)
return json.dumps(dict(channels=channels, nextToken=next_token))
return json.dumps(dict(channels=channels, nextToken=None))
def describe_channel(self):
def describe_channel(self) -> str:
channel_id = self._get_param("channelId")
return json.dumps(
self.medialive_backend.describe_channel(channel_id=channel_id)
)
channel = self.medialive_backend.describe_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def delete_channel(self):
def delete_channel(self) -> str:
channel_id = self._get_param("channelId")
return json.dumps(self.medialive_backend.delete_channel(channel_id=channel_id))
channel = self.medialive_backend.delete_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def start_channel(self):
def start_channel(self) -> str:
channel_id = self._get_param("channelId")
return json.dumps(self.medialive_backend.start_channel(channel_id=channel_id))
channel = self.medialive_backend.start_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def stop_channel(self):
def stop_channel(self) -> str:
channel_id = self._get_param("channelId")
return json.dumps(self.medialive_backend.stop_channel(channel_id=channel_id))
channel = self.medialive_backend.stop_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def update_channel(self):
def update_channel(self) -> str:
channel_id = self._get_param("channelId")
cdi_input_specification = self._get_param("cdiInputSpecification")
destinations = self._get_param("destinations")
@ -89,7 +88,7 @@ class MediaLiveResponse(BaseResponse):
)
return json.dumps(dict(channel=channel.to_dict()))
def create_input(self):
def create_input(self) -> str:
destinations = self._get_param("destinations")
input_devices = self._get_param("inputDevices")
input_security_groups = self._get_param("inputSecurityGroups")
@ -112,25 +111,23 @@ class MediaLiveResponse(BaseResponse):
)
return json.dumps({"input": a_input.to_dict()})
def describe_input(self):
def describe_input(self) -> str:
input_id = self._get_param("inputId")
return json.dumps(self.medialive_backend.describe_input(input_id=input_id))
a_input = self.medialive_backend.describe_input(input_id=input_id)
return json.dumps(a_input.to_dict())
def list_inputs(self):
def list_inputs(self) -> str:
max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken")
inputs, next_token = self.medialive_backend.list_inputs(
max_results=max_results, next_token=next_token
)
inputs = self.medialive_backend.list_inputs(max_results=max_results)
return json.dumps(dict(inputs=inputs, nextToken=next_token))
return json.dumps(dict(inputs=inputs, nextToken=None))
def delete_input(self):
def delete_input(self) -> str:
input_id = self._get_param("inputId")
self.medialive_backend.delete_input(input_id=input_id)
return json.dumps({})
def update_input(self):
def update_input(self) -> str:
destinations = self._get_param("destinations")
input_devices = self._get_param("inputDevices")
input_id = self._get_param("inputId")

View File

@ -7,5 +7,5 @@ class MediaPackageClientError(JsonRESTError):
# AWS service exceptions are caught with the underlying botocore exception, ClientError
class ClientError(MediaPackageClientError):
def __init__(self, error, message):
def __init__(self, error: str, message: str):
super().__init__(error, message)

View File

@ -1,4 +1,5 @@
from collections import OrderedDict
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel
@ -6,27 +7,23 @@ from .exceptions import ClientError
class Channel(BaseModel):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn")
self.channel_id = kwargs.get("channel_id")
self.description = kwargs.get("description")
self.tags = kwargs.get("tags")
def to_dict(self, exclude=None):
data = {
def to_dict(self) -> Dict[str, Any]:
return {
"arn": self.arn,
"id": self.channel_id,
"description": self.description,
"tags": self.tags,
}
if exclude:
for key in exclude:
del data[key]
return data
class OriginEndpoint(BaseModel):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn")
self.authorization = kwargs.get("authorization")
self.channel_id = kwargs.get("channel_id")
@ -44,8 +41,8 @@ class OriginEndpoint(BaseModel):
self.url = kwargs.get("url")
self.whitelist = kwargs.get("whitelist")
def to_dict(self):
data = {
def to_dict(self) -> Dict[str, Any]:
return {
"arn": self.arn,
"authorization": self.authorization,
"channelId": self.channel_id,
@ -63,69 +60,62 @@ class OriginEndpoint(BaseModel):
"url": self.url,
"whitelist": self.whitelist,
}
return data
class MediaPackageBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self._channels = OrderedDict()
self._origin_endpoints = OrderedDict()
self._channels: Dict[str, Channel] = OrderedDict()
self._origin_endpoints: Dict[str, OriginEndpoint] = OrderedDict()
def create_channel(self, description, channel_id, tags):
def create_channel(
self, description: str, channel_id: str, tags: Dict[str, str]
) -> Channel:
arn = f"arn:aws:mediapackage:channel:{channel_id}"
channel = Channel(
arn=arn,
description=description,
egress_access_logs={},
hls_ingest={},
channel_id=channel_id,
ingress_access_logs={},
tags=tags,
)
self._channels[channel_id] = channel
return channel
def list_channels(self):
channels = list(self._channels.values())
response_channels = [c.to_dict() for c in channels]
return response_channels
def list_channels(self) -> List[Dict[str, Any]]:
return [c.to_dict() for c in self._channels.values()]
def describe_channel(self, channel_id):
def describe_channel(self, channel_id: str) -> Channel:
try:
channel = self._channels[channel_id]
return channel.to_dict()
return self._channels[channel_id]
except KeyError:
error = "NotFoundException"
raise ClientError(error, f"channel with id={channel_id} not found")
raise ClientError(
"NotFoundException", f"channel with id={channel_id} not found"
)
def delete_channel(self, channel_id):
try:
channel = self._channels[channel_id]
del self._channels[channel_id]
return channel.to_dict()
except KeyError:
error = "NotFoundException"
raise ClientError(error, f"channel with id={channel_id} not found")
def delete_channel(self, channel_id: str) -> Channel:
if channel_id in self._channels:
return self._channels.pop(channel_id)
raise ClientError(
"NotFoundException", f"channel with id={channel_id} not found"
)
def create_origin_endpoint(
self,
authorization,
channel_id,
cmaf_package,
dash_package,
description,
hls_package,
endpoint_id,
manifest_name,
mss_package,
origination,
startover_window_seconds,
tags,
time_delay_seconds,
whitelist,
):
authorization: Dict[str, str],
channel_id: str,
cmaf_package: Dict[str, Any],
dash_package: Dict[str, Any],
description: str,
hls_package: Dict[str, Any],
endpoint_id: str,
manifest_name: str,
mss_package: Dict[str, Any],
origination: str,
startover_window_seconds: int,
tags: Dict[str, str],
time_delay_seconds: int,
whitelist: List[str],
) -> OriginEndpoint:
arn = f"arn:aws:mediapackage:origin_endpoint:{endpoint_id}"
url = f"https://origin-endpoint.mediapackage.{self.region_name}.amazonaws.com/{endpoint_id}"
origin_endpoint = OriginEndpoint(
@ -149,43 +139,39 @@ class MediaPackageBackend(BaseBackend):
self._origin_endpoints[endpoint_id] = origin_endpoint
return origin_endpoint
def describe_origin_endpoint(self, endpoint_id):
def describe_origin_endpoint(self, endpoint_id: str) -> OriginEndpoint:
try:
origin_endpoint = self._origin_endpoints[endpoint_id]
return origin_endpoint.to_dict()
return self._origin_endpoints[endpoint_id]
except KeyError:
error = "NotFoundException"
raise ClientError(error, f"origin endpoint with id={endpoint_id} not found")
raise ClientError(
"NotFoundException", f"origin endpoint with id={endpoint_id} not found"
)
def list_origin_endpoints(self):
origin_endpoints = list(self._origin_endpoints.values())
response_origin_endpoints = [o.to_dict() for o in origin_endpoints]
return response_origin_endpoints
def list_origin_endpoints(self) -> List[Dict[str, Any]]:
return [o.to_dict() for o in self._origin_endpoints.values()]
def delete_origin_endpoint(self, endpoint_id):
try:
origin_endpoint = self._origin_endpoints[endpoint_id]
del self._origin_endpoints[endpoint_id]
return origin_endpoint.to_dict()
except KeyError:
error = "NotFoundException"
raise ClientError(error, f"origin endpoint with id={endpoint_id} not found")
def delete_origin_endpoint(self, endpoint_id: str) -> OriginEndpoint:
if endpoint_id in self._origin_endpoints:
return self._origin_endpoints.pop(endpoint_id)
raise ClientError(
"NotFoundException", f"origin endpoint with id={endpoint_id} not found"
)
def update_origin_endpoint(
self,
authorization,
cmaf_package,
dash_package,
description,
hls_package,
endpoint_id,
manifest_name,
mss_package,
origination,
startover_window_seconds,
time_delay_seconds,
whitelist,
):
authorization: Dict[str, str],
cmaf_package: Dict[str, Any],
dash_package: Dict[str, Any],
description: str,
hls_package: Dict[str, Any],
endpoint_id: str,
manifest_name: str,
mss_package: Dict[str, Any],
origination: str,
startover_window_seconds: int,
time_delay_seconds: int,
whitelist: List[str],
) -> OriginEndpoint:
try:
origin_endpoint = self._origin_endpoints[endpoint_id]
origin_endpoint.authorization = authorization
@ -200,10 +186,10 @@ class MediaPackageBackend(BaseBackend):
origin_endpoint.time_delay_seconds = time_delay_seconds
origin_endpoint.whitelist = whitelist
return origin_endpoint
except KeyError:
error = "NotFoundException"
raise ClientError(error, f"origin endpoint with id={endpoint_id} not found")
raise ClientError(
"NotFoundException", f"origin endpoint with id={endpoint_id} not found"
)
mediapackage_backends = BackendDict(MediaPackageBackend, "mediapackage")

View File

@ -1,17 +1,17 @@
from moto.core.responses import BaseResponse
from .models import mediapackage_backends
from .models import mediapackage_backends, MediaPackageBackend
import json
class MediaPackageResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="mediapackage")
@property
def mediapackage_backend(self):
def mediapackage_backend(self) -> MediaPackageBackend:
return mediapackage_backends[self.current_account][self.region]
def create_channel(self):
def create_channel(self) -> str:
description = self._get_param("description")
channel_id = self._get_param("id")
tags = self._get_param("tags")
@ -20,23 +20,21 @@ class MediaPackageResponse(BaseResponse):
)
return json.dumps(channel.to_dict())
def list_channels(self):
def list_channels(self) -> str:
channels = self.mediapackage_backend.list_channels()
return json.dumps(dict(channels=channels))
def describe_channel(self):
def describe_channel(self) -> str:
channel_id = self._get_param("id")
return json.dumps(
self.mediapackage_backend.describe_channel(channel_id=channel_id)
)
channel = self.mediapackage_backend.describe_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def delete_channel(self):
def delete_channel(self) -> str:
channel_id = self._get_param("id")
return json.dumps(
self.mediapackage_backend.delete_channel(channel_id=channel_id)
)
channel = self.mediapackage_backend.delete_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def create_origin_endpoint(self):
def create_origin_endpoint(self) -> str:
authorization = self._get_param("authorization")
channel_id = self._get_param("channelId")
cmaf_package = self._get_param("cmafPackage")
@ -65,27 +63,29 @@ class MediaPackageResponse(BaseResponse):
startover_window_seconds=startover_window_seconds,
tags=tags,
time_delay_seconds=time_delay_seconds,
whitelist=whitelist,
whitelist=whitelist, # type: ignore[arg-type]
)
return json.dumps(origin_endpoint.to_dict())
def list_origin_endpoints(self):
def list_origin_endpoints(self) -> str:
origin_endpoints = self.mediapackage_backend.list_origin_endpoints()
return json.dumps(dict(originEndpoints=origin_endpoints))
def describe_origin_endpoint(self):
def describe_origin_endpoint(self) -> str:
endpoint_id = self._get_param("id")
return json.dumps(
self.mediapackage_backend.describe_origin_endpoint(endpoint_id=endpoint_id)
endpoint = self.mediapackage_backend.describe_origin_endpoint(
endpoint_id=endpoint_id
)
return json.dumps(endpoint.to_dict())
def delete_origin_endpoint(self):
def delete_origin_endpoint(self) -> str:
endpoint_id = self._get_param("id")
return json.dumps(
self.mediapackage_backend.delete_origin_endpoint(endpoint_id=endpoint_id)
endpoint = self.mediapackage_backend.delete_origin_endpoint(
endpoint_id=endpoint_id
)
return json.dumps(endpoint.to_dict())
def update_origin_endpoint(self):
def update_origin_endpoint(self) -> str:
authorization = self._get_param("authorization")
cmaf_package = self._get_param("cmafPackage")
dash_package = self._get_param("dashPackage")
@ -110,6 +110,6 @@ class MediaPackageResponse(BaseResponse):
origination=origination,
startover_window_seconds=startover_window_seconds,
time_delay_seconds=time_delay_seconds,
whitelist=whitelist,
whitelist=whitelist, # type: ignore[arg-type]
)
return json.dumps(origin_endpoint.to_dict())

View File

@ -1,3 +1,4 @@
from typing import Optional
from moto.core.exceptions import JsonRESTError
@ -6,7 +7,7 @@ class MediaStoreClientError(JsonRESTError):
class ContainerNotFoundException(MediaStoreClientError):
def __init__(self, msg=None):
def __init__(self, msg: Optional[str] = None):
self.code = 400
super().__init__(
"ContainerNotFoundException",
@ -15,7 +16,7 @@ class ContainerNotFoundException(MediaStoreClientError):
class ResourceNotFoundException(MediaStoreClientError):
def __init__(self, msg=None):
def __init__(self, msg: Optional[str] = None):
self.code = 400
super().__init__(
"ResourceNotFoundException", msg or "The specified container does not exist"
@ -23,7 +24,7 @@ class ResourceNotFoundException(MediaStoreClientError):
class PolicyNotFoundException(MediaStoreClientError):
def __init__(self, msg=None):
def __init__(self, msg: Optional[str] = None):
self.code = 400
super().__init__(
"PolicyNotFoundException",

View File

@ -1,5 +1,6 @@
from collections import OrderedDict
from datetime import date
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import (
@ -10,18 +11,18 @@ from .exceptions import (
class Container(BaseModel):
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn")
self.name = kwargs.get("name")
self.endpoint = kwargs.get("endpoint")
self.status = kwargs.get("status")
self.creation_time = kwargs.get("creation_time")
self.lifecycle_policy = None
self.policy = None
self.metric_policy = None
self.lifecycle_policy: Optional[str] = None
self.policy: Optional[str] = None
self.metric_policy: Optional[str] = None
self.tags = kwargs.get("tags")
def to_dict(self, exclude=None):
def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]:
data = {
"ARN": self.arn,
"Name": self.name,
@ -37,11 +38,11 @@ class Container(BaseModel):
class MediaStoreBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self._containers = OrderedDict()
self._containers: Dict[str, Container] = OrderedDict()
def create_container(self, name, tags):
def create_container(self, name: str, tags: Dict[str, str]) -> Container:
arn = f"arn:aws:mediastore:container:{name}"
container = Container(
arn=arn,
@ -54,40 +55,36 @@ class MediaStoreBackend(BaseBackend):
self._containers[name] = container
return container
def delete_container(self, name):
def delete_container(self, name: str) -> None:
if name not in self._containers:
raise ContainerNotFoundException()
del self._containers[name]
return {}
def describe_container(self, name):
def describe_container(self, name: str) -> Container:
if name not in self._containers:
raise ResourceNotFoundException()
container = self._containers[name]
container.status = "ACTIVE"
return container
def list_containers(self):
def list_containers(self) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
containers = list(self._containers.values())
response_containers = [c.to_dict() for c in containers]
return response_containers, None
return [c.to_dict() for c in self._containers.values()]
def list_tags_for_resource(self, name):
def list_tags_for_resource(self, name: str) -> Optional[Dict[str, str]]:
if name not in self._containers:
raise ContainerNotFoundException()
tags = self._containers[name].tags
return tags
def put_lifecycle_policy(self, container_name, lifecycle_policy):
def put_lifecycle_policy(self, container_name: str, lifecycle_policy: str) -> None:
if container_name not in self._containers:
raise ResourceNotFoundException()
self._containers[container_name].lifecycle_policy = lifecycle_policy
return {}
def get_lifecycle_policy(self, container_name):
def get_lifecycle_policy(self, container_name: str) -> str:
if container_name not in self._containers:
raise ResourceNotFoundException()
lifecycle_policy = self._containers[container_name].lifecycle_policy
@ -95,13 +92,12 @@ class MediaStoreBackend(BaseBackend):
raise PolicyNotFoundException()
return lifecycle_policy
def put_container_policy(self, container_name, policy):
def put_container_policy(self, container_name: str, policy: str) -> None:
if container_name not in self._containers:
raise ResourceNotFoundException()
self._containers[container_name].policy = policy
return {}
def get_container_policy(self, container_name):
def get_container_policy(self, container_name: str) -> str:
if container_name not in self._containers:
raise ResourceNotFoundException()
policy = self._containers[container_name].policy
@ -109,13 +105,12 @@ class MediaStoreBackend(BaseBackend):
raise PolicyNotFoundException()
return policy
def put_metric_policy(self, container_name, metric_policy):
def put_metric_policy(self, container_name: str, metric_policy: str) -> None:
if container_name not in self._containers:
raise ResourceNotFoundException()
self._containers[container_name].metric_policy = metric_policy
return {}
def get_metric_policy(self, container_name):
def get_metric_policy(self, container_name: str) -> str:
if container_name not in self._containers:
raise ResourceNotFoundException()
metric_policy = self._containers[container_name].metric_policy

View File

@ -1,73 +1,73 @@
import json
from moto.core.responses import BaseResponse
from .models import mediastore_backends
from .models import mediastore_backends, MediaStoreBackend
class MediaStoreResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="mediastore")
@property
def mediastore_backend(self):
def mediastore_backend(self) -> MediaStoreBackend:
return mediastore_backends[self.current_account][self.region]
def create_container(self):
def create_container(self) -> str:
name = self._get_param("ContainerName")
tags = self._get_param("Tags")
container = self.mediastore_backend.create_container(name=name, tags=tags)
return json.dumps(dict(Container=container.to_dict()))
def delete_container(self):
def delete_container(self) -> str:
name = self._get_param("ContainerName")
result = self.mediastore_backend.delete_container(name=name)
return json.dumps(result)
self.mediastore_backend.delete_container(name=name)
return "{}"
def describe_container(self):
def describe_container(self) -> str:
name = self._get_param("ContainerName")
container = self.mediastore_backend.describe_container(name=name)
return json.dumps(dict(Container=container.to_dict()))
def list_containers(self):
containers, next_token = self.mediastore_backend.list_containers()
return json.dumps(dict(dict(Containers=containers), NextToken=next_token))
def list_containers(self) -> str:
containers = self.mediastore_backend.list_containers()
return json.dumps(dict(dict(Containers=containers), NextToken=None))
def list_tags_for_resource(self):
def list_tags_for_resource(self) -> str:
name = self._get_param("Resource")
tags = self.mediastore_backend.list_tags_for_resource(name)
return json.dumps(dict(Tags=tags))
def put_lifecycle_policy(self):
def put_lifecycle_policy(self) -> str:
container_name = self._get_param("ContainerName")
lifecycle_policy = self._get_param("LifecyclePolicy")
policy = self.mediastore_backend.put_lifecycle_policy(
self.mediastore_backend.put_lifecycle_policy(
container_name=container_name, lifecycle_policy=lifecycle_policy
)
return json.dumps(policy)
return "{}"
def get_lifecycle_policy(self):
def get_lifecycle_policy(self) -> str:
container_name = self._get_param("ContainerName")
lifecycle_policy = self.mediastore_backend.get_lifecycle_policy(
container_name=container_name
)
return json.dumps(dict(LifecyclePolicy=lifecycle_policy))
def put_container_policy(self):
def put_container_policy(self) -> str:
container_name = self._get_param("ContainerName")
policy = self._get_param("Policy")
container_policy = self.mediastore_backend.put_container_policy(
self.mediastore_backend.put_container_policy(
container_name=container_name, policy=policy
)
return json.dumps(container_policy)
return "{}"
def get_container_policy(self):
def get_container_policy(self) -> str:
container_name = self._get_param("ContainerName")
policy = self.mediastore_backend.get_container_policy(
container_name=container_name
)
return json.dumps(dict(Policy=policy))
def put_metric_policy(self):
def put_metric_policy(self) -> str:
container_name = self._get_param("ContainerName")
metric_policy = self._get_param("MetricPolicy")
self.mediastore_backend.put_metric_policy(
@ -75,12 +75,9 @@ class MediaStoreResponse(BaseResponse):
)
return json.dumps(metric_policy)
def get_metric_policy(self):
def get_metric_policy(self) -> str:
container_name = self._get_param("ContainerName")
metric_policy = self.mediastore_backend.get_metric_policy(
container_name=container_name
)
return json.dumps(dict(MetricPolicy=metric_policy))
# add templates from here

View File

@ -7,5 +7,5 @@ class MediaStoreDataClientError(JsonRESTError):
# AWS service exceptions are caught with the underlying botocore exception, ClientError
class ClientError(MediaStoreDataClientError):
def __init__(self, error, message):
def __init__(self, error: str, message: str):
super().__init__(error, message)

View File

@ -1,20 +1,23 @@
import hashlib
from collections import OrderedDict
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import ClientError
class Object(BaseModel):
def __init__(self, path, body, etag, storage_class="TEMPORAL"):
def __init__(
self, path: str, body: str, etag: str, storage_class: str = "TEMPORAL"
):
self.path = path
self.body = body
self.content_sha256 = hashlib.sha256(body.encode("utf-8")).hexdigest()
self.etag = etag
self.storage_class = storage_class
def to_dict(self):
data = {
def to_dict(self) -> Dict[str, Any]:
return {
"ETag": self.etag,
"Name": self.path,
"Type": "FILE",
@ -24,15 +27,15 @@ class Object(BaseModel):
"ContentSHA256": self.content_sha256,
}
return data
class MediaStoreDataBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self._objects = OrderedDict()
self._objects: Dict[str, Object] = OrderedDict()
def put_object(self, body, path, storage_class="TEMPORAL"):
def put_object(
self, body: str, path: str, storage_class: str = "TEMPORAL"
) -> Object:
"""
The following parameters are not yet implemented: ContentType, CacheControl, UploadAvailability
"""
@ -42,30 +45,29 @@ class MediaStoreDataBackend(BaseBackend):
self._objects[path] = new_object
return new_object
def delete_object(self, path):
def delete_object(self, path: str) -> None:
if path not in self._objects:
error = "ObjectNotFoundException"
raise ClientError(error, f"Object with id={path} not found")
raise ClientError(
"ObjectNotFoundException", f"Object with id={path} not found"
)
del self._objects[path]
return {}
def get_object(self, path):
def get_object(self, path: str) -> Object:
"""
The Range-parameter is not yet supported.
"""
objects_found = [item for item in self._objects.values() if item.path == path]
if len(objects_found) == 0:
error = "ObjectNotFoundException"
raise ClientError(error, f"Object with id={path} not found")
raise ClientError(
"ObjectNotFoundException", f"Object with id={path} not found"
)
return objects_found[0]
def list_items(self):
def list_items(self) -> List[Dict[str, Any]]:
"""
The Path- and MaxResults-parameters are not yet supported.
"""
items = self._objects.values()
response_items = [c.to_dict() for c in items]
return response_items
return [c.to_dict() for c in self._objects.values()]
mediastoredata_backends = BackendDict(MediaStoreDataBackend, "mediastore-data")

View File

@ -1,35 +1,36 @@
import json
from typing import Dict, Tuple
from moto.core.responses import BaseResponse
from .models import mediastoredata_backends
from .models import mediastoredata_backends, MediaStoreDataBackend
class MediaStoreDataResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="mediastore-data")
@property
def mediastoredata_backend(self):
def mediastoredata_backend(self) -> MediaStoreDataBackend:
return mediastoredata_backends[self.current_account][self.region]
def get_object(self):
def get_object(self) -> Tuple[str, Dict[str, str]]:
path = self._get_param("Path")
result = self.mediastoredata_backend.get_object(path=path)
headers = {"Path": result.path}
return result.body, headers
def put_object(self):
def put_object(self) -> str:
body = self.body
path = self._get_param("Path")
new_object = self.mediastoredata_backend.put_object(body, path)
object_dict = new_object.to_dict()
return json.dumps(object_dict)
def delete_object(self):
def delete_object(self) -> str:
item_id = self._get_param("Path")
result = self.mediastoredata_backend.delete_object(path=item_id)
return json.dumps(result)
self.mediastoredata_backend.delete_object(path=item_id)
return "{}"
def list_items(self):
def list_items(self) -> str:
items = self.mediastoredata_backend.list_items()
return json.dumps(dict(Items=items))

View File

@ -1,38 +0,0 @@
from moto.core.exceptions import JsonRESTError
class DisabledApiException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="DisabledApiException", message=message)
class InternalServiceErrorException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="InternalServiceErrorException", message=message)
class InvalidCustomerIdentifierException(JsonRESTError):
def __init__(self, message):
super().__init__(
error_type="InvalidCustomerIdentifierException", message=message
)
class InvalidProductCodeException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="InvalidProductCodeException", message=message)
class InvalidUsageDimensionException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="InvalidUsageDimensionException", message=message)
class ThrottlingException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="ThrottlingException", message=message)
class TimestampOutOfBoundsException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="TimestampOutOfBoundsException", message=message)

View File

@ -1,10 +1,17 @@
import collections
from typing import Any, Deque, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random
class UsageRecord(BaseModel, dict):
def __init__(self, timestamp, customer_identifier, dimension, quantity=0):
class UsageRecord(BaseModel, Dict[str, Any]): # type: ignore[misc]
def __init__(
self,
timestamp: str,
customer_identifier: str,
dimension: str,
quantity: int = 0,
):
super().__init__()
self.timestamp = timestamp
self.customer_identifier = customer_identifier
@ -12,54 +19,45 @@ class UsageRecord(BaseModel, dict):
self.quantity = quantity
self.metering_record_id = mock_random.uuid4().hex
@classmethod
def from_data(cls, data):
cls(
timestamp=data.get("Timestamp"),
customer_identifier=data.get("CustomerIdentifier"),
dimension=data.get("Dimension"),
quantity=data.get("Quantity", 0),
)
@property
def timestamp(self):
def timestamp(self) -> str:
return self["Timestamp"]
@timestamp.setter
def timestamp(self, value):
def timestamp(self, value: str) -> None:
self["Timestamp"] = value
@property
def customer_identifier(self):
def customer_identifier(self) -> str:
return self["CustomerIdentifier"]
@customer_identifier.setter
def customer_identifier(self, value):
def customer_identifier(self, value: str) -> None:
self["CustomerIdentifier"] = value
@property
def dimension(self):
def dimension(self) -> str:
return self["Dimension"]
@dimension.setter
def dimension(self, value):
def dimension(self, value: str) -> None:
self["Dimension"] = value
@property
def quantity(self):
def quantity(self) -> int:
return self["Quantity"]
@quantity.setter
def quantity(self, value):
def quantity(self, value: int) -> None:
self["Quantity"] = value
class Result(BaseModel, dict):
class Result(BaseModel, Dict[str, Any]): # type: ignore[misc]
SUCCESS = "Success"
CUSTOMER_NOT_SUBSCRIBED = "CustomerNotSubscribed"
DUPLICATE_RECORD = "DuplicateRecord"
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.usage_record = UsageRecord(
timestamp=kwargs["Timestamp"],
customer_identifier=kwargs["CustomerIdentifier"],
@ -70,28 +68,26 @@ class Result(BaseModel, dict):
self["MeteringRecordId"] = self.usage_record.metering_record_id
@property
def metering_record_id(self):
def metering_record_id(self) -> str:
return self["MeteringRecordId"]
@property
def status(self):
def status(self) -> str:
return self["Status"]
@status.setter
def status(self, value):
def status(self, value: str) -> None:
self["Status"] = value
@property
def usage_record(self):
def usage_record(self) -> UsageRecord:
return self["UsageRecord"]
@usage_record.setter
def usage_record(self, value):
if not isinstance(value, UsageRecord):
value = UsageRecord.from_data(value)
def usage_record(self, value: UsageRecord) -> None:
self["UsageRecord"] = value
def is_duplicate(self, other):
def is_duplicate(self, other: Any) -> bool:
"""
DuplicateRecord - Indicates that the UsageRecord was invalid and not honored.
A previously metered UsageRecord had the same customer, dimension, and time,
@ -107,23 +103,29 @@ class Result(BaseModel, dict):
)
class CustomerDeque(collections.deque):
def is_subscribed(self, customer):
class CustomerDeque(Deque[str]):
def is_subscribed(self, customer: str) -> bool:
return customer in self
class ResultDeque(collections.deque):
def is_duplicate(self, result):
class ResultDeque(Deque[Result]):
def is_duplicate(self, result: Result) -> bool:
return any(record.is_duplicate(result) for record in self)
class MeteringMarketplaceBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.customers_by_product = collections.defaultdict(CustomerDeque)
self.records_by_product = collections.defaultdict(ResultDeque)
self.customers_by_product: Dict[str, CustomerDeque] = collections.defaultdict(
CustomerDeque
)
self.records_by_product: Dict[str, ResultDeque] = collections.defaultdict(
ResultDeque
)
def batch_meter_usage(self, product_code, usage_records):
def batch_meter_usage(
self, product_code: str, usage_records: List[Dict[str, Any]]
) -> List[Result]:
results = []
for usage in usage_records:
result = Result(**usage)

View File

@ -9,8 +9,7 @@ class MarketplaceMeteringResponse(BaseResponse):
def backend(self) -> MeteringMarketplaceBackend:
return meteringmarketplace_backends[self.current_account][self.region]
def batch_meter_usage(self):
results = []
def batch_meter_usage(self) -> str:
usage_records = json.loads(self.body)["UsageRecords"]
product_code = json.loads(self.body)["ProductCode"]
results = self.backend.batch_meter_usage(product_code, usage_records)

View File

@ -1,29 +1,32 @@
import time
from threading import Thread
from werkzeug.serving import make_server
from typing import Optional
from werkzeug.serving import make_server, BaseWSGIServer
from .werkzeug_app import DomainDispatcherApplication, create_backend_app
class ThreadedMotoServer:
def __init__(self, ip_address="0.0.0.0", port=5000, verbose=True):
def __init__(
self, ip_address: str = "0.0.0.0", port: int = 5000, verbose: bool = True
):
self._port = port
self._thread = None
self._thread: Optional[Thread] = None
self._ip_address = ip_address
self._server = None
self._server: Optional[BaseWSGIServer] = None
self._server_ready = False
self._verbose = verbose
def _server_entry(self):
def _server_entry(self) -> None:
app = DomainDispatcherApplication(create_backend_app)
self._server = make_server(self._ip_address, self._port, app, True)
self._server_ready = True
self._server.serve_forever()
def start(self):
def start(self) -> None:
if self._verbose:
print( # noqa
f"Starting a new Thread with MotoServer running on {self._ip_address}:{self._port}..."
@ -33,9 +36,9 @@ class ThreadedMotoServer:
while not self._server_ready:
time.sleep(0.1)
def stop(self):
def stop(self) -> None:
self._server_ready = False
if self._server:
self._server.shutdown()
self._thread.join()
self._thread.join() # type: ignore[union-attr]

View File

@ -1,6 +1,6 @@
import json
from flask.testing import FlaskClient
from typing import Any, Dict
from urllib.parse import urlencode
from werkzeug.routing import BaseConverter
@ -10,13 +10,13 @@ class RegexConverter(BaseConverter):
part_isolating = False
def __init__(self, url_map, *items):
def __init__(self, url_map: Any, *items: Any):
super().__init__(url_map)
self.regex = items[0]
class AWSTestHelper(FlaskClient):
def action_data(self, action_name, **kwargs):
def action_data(self, action_name: str, **kwargs: Any) -> str:
"""
Method calls resource with action_name and returns data of response.
"""
@ -24,11 +24,11 @@ class AWSTestHelper(FlaskClient):
opts.update(kwargs)
res = self.get(
f"/?{urlencode(opts)}",
headers={"Host": f"{self.application.service}.us-east-1.amazonaws.com"},
headers={"Host": f"{self.application.service}.us-east-1.amazonaws.com"}, # type: ignore[attr-defined]
)
return res.data.decode("utf-8")
def action_json(self, action_name, **kwargs):
def action_json(self, action_name: str, **kwargs: Any) -> Dict[str, Any]:
"""
Method calls resource with action_name and returns object obtained via
deserialization of output.

View File

@ -2,6 +2,7 @@ import io
import os
import os.path
from threading import Lock
from typing import Any, Callable, Dict, Optional, Tuple
try:
from flask import Flask
@ -49,20 +50,22 @@ SIGNING_ALIASES = {
SERVICE_BY_VERSION = {"2009-04-15": "sdb"}
class DomainDispatcherApplication(object):
class DomainDispatcherApplication:
"""
Dispatch requests to different applications based on the "Host:" header
value. We'll match the host header value with the url_bases of each backend.
"""
def __init__(self, create_app, service=None):
def __init__(
self, create_app: Callable[[str], Flask], service: Optional[str] = None
):
self.create_app = create_app
self.lock = Lock()
self.app_instances = {}
self.app_instances: Dict[str, Flask] = {}
self.service = service
self.backend_url_patterns = backend_index.backend_url_patterns
def get_backend_for_host(self, host):
def get_backend_for_host(self, host: str) -> Any:
if host == "moto_api":
return host
@ -83,7 +86,9 @@ class DomainDispatcherApplication(object):
"Remember to add the URL to urls.py, and run scripts/update_backend_index.py to index it."
)
def infer_service_region_host(self, body, environ):
def infer_service_region_host(
self, body: Optional[str], environ: Dict[str, Any]
) -> str:
auth = environ.get("HTTP_AUTHORIZATION")
target = environ.get("HTTP_X_AMZ_TARGET")
service = None
@ -111,7 +116,7 @@ class DomainDispatcherApplication(object):
service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION)
elif action and action in UNSIGNED_ACTIONS:
# See if we can match the Action to a known service
service, region = UNSIGNED_ACTIONS.get(action)
service, region = UNSIGNED_ACTIONS[action]
if not service:
service, region = self.get_service_from_body(body, environ)
if not service:
@ -153,7 +158,7 @@ class DomainDispatcherApplication(object):
return host
def get_application(self, environ):
def get_application(self, environ: Dict[str, Any]) -> Flask:
path_info = environ.get("PATH_INFO", "")
# The URL path might contain non-ASCII text, for instance unicode S3 bucket names
@ -181,7 +186,7 @@ class DomainDispatcherApplication(object):
self.app_instances[backend] = app
return app
def _get_body(self, environ):
def _get_body(self, environ: Dict[str, Any]) -> Optional[str]:
body = None
try:
# AWS requests use querystrings as the body (Action=x&Data=y&...)
@ -190,7 +195,7 @@ class DomainDispatcherApplication(object):
)
request_body_size = int(environ["CONTENT_LENGTH"])
if simple_form and request_body_size:
body = environ["wsgi.input"].read(request_body_size).decode("utf-8")
body = environ["wsgi.input"].read(request_body_size).decode("utf-8") # type: ignore
except (KeyError, ValueError):
pass
finally:
@ -199,7 +204,9 @@ class DomainDispatcherApplication(object):
environ["wsgi.input"] = io.StringIO(body)
return body
def get_service_from_body(self, body, environ):
def get_service_from_body(
self, body: Optional[str], environ: Dict[str, Any]
) -> Tuple[Optional[str], Optional[str]]:
# Some services have the SDK Version in the body
# If the version is unique, we can derive the service from it
version = self.get_version_from_body(body)
@ -209,22 +216,24 @@ class DomainDispatcherApplication(object):
return SERVICE_BY_VERSION[version], region
return None, None
def get_version_from_body(self, body):
def get_version_from_body(self, body: Optional[str]) -> Optional[str]:
try:
body_dict = dict(x.split("=") for x in body.split("&"))
body_dict = dict(x.split("=") for x in body.split("&")) # type: ignore
return body_dict["Version"]
except (AttributeError, KeyError, ValueError):
return None
def get_action_from_body(self, body):
def get_action_from_body(self, body: Optional[str]) -> Optional[str]:
try:
# AWS requests use querystrings as the body (Action=x&Data=y&...)
body_dict = dict(x.split("=") for x in body.split("&"))
body_dict = dict(x.split("=") for x in body.split("&")) # type: ignore
return body_dict["Action"]
except (AttributeError, KeyError, ValueError):
return None
def get_service_from_path(self, environ):
def get_service_from_path(
self, environ: Dict[str, Any]
) -> Tuple[Optional[str], Optional[str]]:
# Moto sometimes needs to send a HTTP request to itself
# In which case it will send a request to 'http://localhost/service_region/whatever'
try:
@ -234,12 +243,12 @@ class DomainDispatcherApplication(object):
except (AttributeError, KeyError, ValueError):
return None, None
def __call__(self, environ, start_response):
def __call__(self, environ: Dict[str, Any], start_response: Any) -> Any:
backend_app = self.get_application(environ)
return backend_app(environ, start_response)
def create_backend_app(service):
def create_backend_app(service: str) -> Flask:
from werkzeug.routing import Map
current_file = os.path.abspath(__file__)
@ -249,7 +258,7 @@ def create_backend_app(service):
# Create the backend_app
backend_app = Flask("moto", template_folder=template_dir)
backend_app.debug = True
backend_app.service = service
backend_app.service = service # type: ignore[attr-defined]
CORS(backend_app)
# Reset view functions to reset the app

View File

@ -1,4 +1,5 @@
import json
from typing import Any
from moto.core.exceptions import JsonRESTError
@ -7,11 +8,13 @@ class MQError(JsonRESTError):
class UnknownBroker(MQError):
def __init__(self, broker_id):
def __init__(self, broker_id: str):
super().__init__("NotFoundException", "Can't find requested broker")
self.broker_id = broker_id
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = {
"errorAttribute": "broker-id",
"message": f"Can't find requested broker [{self.broker_id}]. Make sure your broker exists.",
@ -20,11 +23,13 @@ class UnknownBroker(MQError):
class UnknownConfiguration(MQError):
def __init__(self, config_id):
def __init__(self, config_id: str):
super().__init__("NotFoundException", "Can't find requested configuration")
self.config_id = config_id
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = {
"errorAttribute": "configuration_id",
"message": f"Can't find requested configuration [{self.config_id}]. Make sure your configuration exists.",
@ -33,11 +38,13 @@ class UnknownConfiguration(MQError):
class UnknownUser(MQError):
def __init__(self, username):
def __init__(self, username: str):
super().__init__("NotFoundException", "Can't find requested user")
self.username = username
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = {
"errorAttribute": "username",
"message": f"Can't find requested user [{self.username}]. Make sure your user exists.",
@ -46,11 +53,13 @@ class UnknownUser(MQError):
class UnsupportedEngineType(MQError):
def __init__(self, engine_type):
def __init__(self, engine_type: str):
super().__init__("BadRequestException", "")
self.engine_type = engine_type
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = {
"errorAttribute": "engineType",
"message": f"Broker engine type [{self.engine_type}] does not support configuration.",
@ -59,11 +68,13 @@ class UnsupportedEngineType(MQError):
class UnknownEngineType(MQError):
def __init__(self, engine_type):
def __init__(self, engine_type: str):
super().__init__("BadRequestException", "")
self.engine_type = engine_type
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = {
"errorAttribute": "engineType",
"message": f"Broker engine type [{self.engine_type}] is invalid. Valid values are: [ACTIVEMQ]",

View File

@ -1,5 +1,6 @@
import base64
import xmltodict
from typing import Any, Dict, List, Iterable, Optional, Tuple
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.core.utils import unix_time
@ -17,7 +18,13 @@ from .exceptions import (
class ConfigurationRevision(BaseModel):
def __init__(self, configuration_id, revision_id, description, data=None):
def __init__(
self,
configuration_id: str,
revision_id: str,
description: str,
data: Optional[str] = None,
):
self.configuration_id = configuration_id
self.created = unix_time()
self.description = description
@ -31,7 +38,7 @@ class ConfigurationRevision(BaseModel):
else:
self.data = data
def has_ldap_auth(self):
def has_ldap_auth(self) -> bool:
try:
xml = base64.b64decode(self.data)
dct = xmltodict.parse(xml, dict_constructor=dict)
@ -45,7 +52,7 @@ class ConfigurationRevision(BaseModel):
# If anything fails, lets assume it's not LDAP
return False
def to_json(self, full=True):
def to_json(self, full: bool = True) -> Dict[str, Any]:
resp = {
"created": self.created,
"description": self.description,
@ -58,7 +65,14 @@ class ConfigurationRevision(BaseModel):
class Configuration(BaseModel):
def __init__(self, account_id, region, name, engine_type, engine_version):
def __init__(
self,
account_id: str,
region: str,
name: str,
engine_type: str,
engine_version: str,
):
self.id = f"c-{mock_random.get_random_hex(6)}"
self.arn = f"arn:aws:mq:{region}:{account_id}:configuration:{self.id}"
self.created = unix_time()
@ -67,7 +81,7 @@ class Configuration(BaseModel):
self.engine_type = engine_type
self.engine_version = engine_version
self.revisions = dict()
self.revisions: Dict[str, ConfigurationRevision] = dict()
default_desc = (
f"Auto-generated default for {self.name} on {engine_type} {engine_version}"
)
@ -80,7 +94,7 @@ class Configuration(BaseModel):
"ldap" if latest_revision.has_ldap_auth() else "simple"
)
def update(self, data, description):
def update(self, data: str, description: str) -> None:
max_revision_id, _ = sorted(self.revisions.items())[-1]
next_revision_id = str(int(max_revision_id) + 1)
latest_revision = ConfigurationRevision(
@ -95,10 +109,10 @@ class Configuration(BaseModel):
"ldap" if latest_revision.has_ldap_auth() else "simple"
)
def get_revision(self, revision_id):
def get_revision(self, revision_id: str) -> ConfigurationRevision:
return self.revisions[revision_id]
def to_json(self):
def to_json(self) -> Dict[str, Any]:
_, latest_revision = sorted(self.revisions.items())[-1]
return {
"arn": self.arn,
@ -113,22 +127,30 @@ class Configuration(BaseModel):
class User(BaseModel):
def __init__(self, broker_id, username, console_access=None, groups=None):
def __init__(
self,
broker_id: str,
username: str,
console_access: Optional[bool] = None,
groups: Optional[List[str]] = None,
):
self.broker_id = broker_id
self.username = username
self.console_access = console_access or False
self.groups = groups or []
def update(self, console_access, groups):
def update(
self, console_access: Optional[bool], groups: Optional[List[str]]
) -> None:
if console_access is not None:
self.console_access = console_access
if groups:
self.groups = groups
def summary(self):
def summary(self) -> Dict[str, str]:
return {"username": self.username}
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {
"brokerId": self.broker_id,
"username": self.username,
@ -140,25 +162,25 @@ class User(BaseModel):
class Broker(BaseModel):
def __init__(
self,
name,
account_id,
region,
authentication_strategy,
auto_minor_version_upgrade,
configuration,
deployment_mode,
encryption_options,
engine_type,
engine_version,
host_instance_type,
ldap_server_metadata,
logs,
maintenance_window_start_time,
publicly_accessible,
security_groups,
storage_type,
subnet_ids,
users,
name: str,
account_id: str,
region: str,
authentication_strategy: str,
auto_minor_version_upgrade: bool,
configuration: Dict[str, Any],
deployment_mode: str,
encryption_options: Dict[str, Any],
engine_type: str,
engine_version: str,
host_instance_type: str,
ldap_server_metadata: Dict[str, Any],
logs: Dict[str, bool],
maintenance_window_start_time: Dict[str, str],
publicly_accessible: bool,
security_groups: List[str],
storage_type: str,
subnet_ids: List[str],
users: List[Dict[str, Any]],
):
self.name = name
self.id = mock_random.get_random_hex(6)
@ -206,7 +228,7 @@ class Broker(BaseModel):
else:
self.subnet_ids = ["default-subnet"]
self.users = dict()
self.users: Dict[str, User] = dict()
for user in users:
self.create_user(
username=user["username"],
@ -215,13 +237,13 @@ class Broker(BaseModel):
)
if self.engine_type.upper() == "RABBITMQ":
self.configurations = None
self.configurations: Optional[Dict[str, Any]] = None
else:
current_config = configuration or {
"id": f"c-{mock_random.get_random_hex(6)}",
"revision": 1,
}
self.configurations = {
self.configurations = { # type: ignore[no-redef]
"current": current_config,
"history": [],
}
@ -256,23 +278,23 @@ class Broker(BaseModel):
def update(
self,
authentication_strategy,
auto_minor_version_upgrade,
configuration,
engine_version,
host_instance_type,
ldap_server_metadata,
logs,
maintenance_window_start_time,
security_groups,
):
authentication_strategy: Optional[str],
auto_minor_version_upgrade: Optional[bool],
configuration: Optional[Dict[str, Any]],
engine_version: Optional[str],
host_instance_type: Optional[str],
ldap_server_metadata: Optional[Dict[str, Any]],
logs: Optional[Dict[str, bool]],
maintenance_window_start_time: Optional[Dict[str, str]],
security_groups: Optional[List[str]],
) -> None:
if authentication_strategy:
self.authentication_strategy = authentication_strategy
if auto_minor_version_upgrade is not None:
self.auto_minor_version_upgrade = auto_minor_version_upgrade
if configuration:
self.configurations["history"].append(self.configurations["current"])
self.configurations["current"] = configuration
self.configurations["history"].append(self.configurations["current"]) # type: ignore[index]
self.configurations["current"] = configuration # type: ignore[index]
if engine_version:
self.engine_version = engine_version
if host_instance_type:
@ -286,29 +308,33 @@ class Broker(BaseModel):
if security_groups:
self.security_groups = security_groups
def reboot(self):
def reboot(self) -> None:
pass
def create_user(self, username, console_access, groups):
def create_user(
self, username: str, console_access: bool, groups: List[str]
) -> None:
user = User(self.id, username, console_access, groups)
self.users[username] = user
def update_user(self, username, console_access, groups):
def update_user(
self, username: str, console_access: bool, groups: List[str]
) -> None:
user = self.get_user(username)
user.update(console_access, groups)
def get_user(self, username):
def get_user(self, username: str) -> User:
if username not in self.users:
raise UnknownUser(username)
return self.users[username]
def delete_user(self, username):
def delete_user(self, username: str) -> None:
self.users.pop(username, None)
def list_users(self):
def list_users(self) -> Iterable[User]:
return self.users.values()
def summary(self):
def summary(self) -> Dict[str, Any]:
return {
"brokerArn": self.arn,
"brokerId": self.id,
@ -320,7 +346,7 @@ class Broker(BaseModel):
"hostInstanceType": self.host_instance_type,
}
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {
"brokerId": self.id,
"brokerArn": self.arn,
@ -352,33 +378,33 @@ class MQBackend(BaseBackend):
No EC2 integration exists yet - subnet ID's and security group values are not validated. Default values may not exist.
"""
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.brokers = dict()
self.configs = dict()
self.brokers: Dict[str, Broker] = dict()
self.configs: Dict[str, Configuration] = dict()
self.tagger = TaggingService()
def create_broker(
self,
authentication_strategy,
auto_minor_version_upgrade,
broker_name,
configuration,
deployment_mode,
encryption_options,
engine_type,
engine_version,
host_instance_type,
ldap_server_metadata,
logs,
maintenance_window_start_time,
publicly_accessible,
security_groups,
storage_type,
subnet_ids,
tags,
users,
):
authentication_strategy: str,
auto_minor_version_upgrade: bool,
broker_name: str,
configuration: Dict[str, Any],
deployment_mode: str,
encryption_options: Dict[str, Any],
engine_type: str,
engine_version: str,
host_instance_type: str,
ldap_server_metadata: Dict[str, Any],
logs: Dict[str, bool],
maintenance_window_start_time: Dict[str, str],
publicly_accessible: bool,
security_groups: List[str],
storage_type: str,
subnet_ids: List[str],
tags: Dict[str, str],
users: List[Dict[str, Any]],
) -> Tuple[str, str]:
broker = Broker(
name=broker_name,
account_id=self.account_id,
@ -404,44 +430,50 @@ class MQBackend(BaseBackend):
self.create_tags(broker.arn, tags)
return broker.arn, broker.id
def delete_broker(self, broker_id):
def delete_broker(self, broker_id: str) -> None:
del self.brokers[broker_id]
def describe_broker(self, broker_id):
def describe_broker(self, broker_id: str) -> Broker:
if broker_id not in self.brokers:
raise UnknownBroker(broker_id)
return self.brokers[broker_id]
def reboot_broker(self, broker_id):
def reboot_broker(self, broker_id: str) -> None:
self.brokers[broker_id].reboot()
def list_brokers(self):
def list_brokers(self) -> Iterable[Broker]:
"""
Pagination is not yet implemented
"""
return self.brokers.values()
def create_user(self, broker_id, username, console_access, groups):
def create_user(
self, broker_id: str, username: str, console_access: bool, groups: List[str]
) -> None:
broker = self.describe_broker(broker_id)
broker.create_user(username, console_access, groups)
def update_user(self, broker_id, console_access, groups, username):
def update_user(
self, broker_id: str, console_access: bool, groups: List[str], username: str
) -> None:
broker = self.describe_broker(broker_id)
broker.update_user(username, console_access, groups)
def describe_user(self, broker_id, username):
def describe_user(self, broker_id: str, username: str) -> User:
broker = self.describe_broker(broker_id)
return broker.get_user(username)
def delete_user(self, broker_id, username):
def delete_user(self, broker_id: str, username: str) -> None:
broker = self.describe_broker(broker_id)
broker.delete_user(username)
def list_users(self, broker_id):
def list_users(self, broker_id: str) -> Iterable[User]:
broker = self.describe_broker(broker_id)
return broker.list_users()
def create_configuration(self, name, engine_type, engine_version, tags):
def create_configuration(
self, name: str, engine_type: str, engine_version: str, tags: Dict[str, str]
) -> Configuration:
if engine_type.upper() == "RABBITMQ":
raise UnsupportedEngineType(engine_type)
if engine_type.upper() != "ACTIVEMQ":
@ -459,7 +491,9 @@ class MQBackend(BaseBackend):
)
return config
def update_configuration(self, config_id, data, description):
def update_configuration(
self, config_id: str, data: str, description: str
) -> Configuration:
"""
No validation occurs on the provided XML. The authenticationStrategy may be changed depending on the provided configuration.
"""
@ -467,47 +501,49 @@ class MQBackend(BaseBackend):
config.update(data, description)
return config
def describe_configuration(self, config_id):
def describe_configuration(self, config_id: str) -> Configuration:
if config_id not in self.configs:
raise UnknownConfiguration(config_id)
return self.configs[config_id]
def describe_configuration_revision(self, config_id, revision_id):
def describe_configuration_revision(
self, config_id: str, revision_id: str
) -> ConfigurationRevision:
config = self.configs[config_id]
return config.get_revision(revision_id)
def list_configurations(self):
def list_configurations(self) -> Iterable[Configuration]:
"""
Pagination has not yet been implemented.
"""
return self.configs.values()
def create_tags(self, resource_arn, tags):
def create_tags(self, resource_arn: str, tags: Dict[str, str]) -> None:
self.tagger.tag_resource(
resource_arn, self.tagger.convert_dict_to_tags_input(tags)
)
def list_tags(self, arn):
def list_tags(self, arn: str) -> Dict[str, str]:
return self.tagger.get_tag_dict_for_resource(arn)
def delete_tags(self, resource_arn, tag_keys):
def delete_tags(self, resource_arn: str, tag_keys: List[str]) -> None:
if not isinstance(tag_keys, list):
tag_keys = [tag_keys]
self.tagger.untag_resource_using_names(resource_arn, tag_keys)
def update_broker(
self,
authentication_strategy,
auto_minor_version_upgrade,
broker_id,
configuration,
engine_version,
host_instance_type,
ldap_server_metadata,
logs,
maintenance_window_start_time,
security_groups,
):
authentication_strategy: str,
auto_minor_version_upgrade: bool,
broker_id: str,
configuration: Dict[str, Any],
engine_version: str,
host_instance_type: str,
ldap_server_metadata: Dict[str, Any],
logs: Dict[str, bool],
maintenance_window_start_time: Dict[str, str],
security_groups: List[str],
) -> None:
broker = self.describe_broker(broker_id)
broker.update(
authentication_strategy=authentication_strategy,

View File

@ -1,23 +1,25 @@
"""Handles incoming mq requests, invokes methods, returns responses."""
import json
from typing import Any
from urllib.parse import unquote
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from .models import mq_backends
from .models import mq_backends, MQBackend
class MQResponse(BaseResponse):
"""Handler for MQ requests and responses."""
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="mq")
@property
def mq_backend(self):
def mq_backend(self) -> MQBackend:
"""Return backend instance specific for this region."""
return mq_backends[self.current_account][self.region]
def broker(self, request, full_url, headers):
def broker(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self.describe_broker()
@ -26,40 +28,40 @@ class MQResponse(BaseResponse):
if request.method == "PUT":
return self.update_broker()
def brokers(self, request, full_url, headers):
def brokers(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self.create_broker()
if request.method == "GET":
return self.list_brokers()
def configuration(self, request, full_url, headers):
def configuration(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self.describe_configuration()
if request.method == "PUT":
return self.update_configuration()
def configurations(self, request, full_url, headers):
def configurations(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self.create_configuration()
if request.method == "GET":
return self.list_configurations()
def configuration_revision(self, request, full_url, headers):
def configuration_revision(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self.get_configuration_revision()
def tags(self, request, full_url, headers):
def tags(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self.create_tags()
if request.method == "DELETE":
return self.delete_tags()
def user(self, request, full_url, headers):
def user(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self.create_user()
@ -70,12 +72,12 @@ class MQResponse(BaseResponse):
if request.method == "DELETE":
return self.delete_user()
def users(self, request, full_url, headers):
def users(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self.list_users()
def create_broker(self):
def create_broker(self) -> TYPE_RESPONSE:
params = json.loads(self.body)
authentication_strategy = params.get("authenticationStrategy")
auto_minor_version_upgrade = params.get("autoMinorVersionUpgrade")
@ -119,7 +121,7 @@ class MQResponse(BaseResponse):
resp = {"brokerArn": broker_arn, "brokerId": broker_id}
return 200, {}, json.dumps(resp)
def update_broker(self):
def update_broker(self) -> TYPE_RESPONSE:
params = json.loads(self.body)
broker_id = self.path.split("/")[-1]
authentication_strategy = params.get("authenticationStrategy")
@ -145,23 +147,23 @@ class MQResponse(BaseResponse):
)
return self.describe_broker()
def delete_broker(self):
def delete_broker(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-1]
self.mq_backend.delete_broker(broker_id=broker_id)
return 200, {}, json.dumps(dict(brokerId=broker_id))
def describe_broker(self):
def describe_broker(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-1]
broker = self.mq_backend.describe_broker(broker_id=broker_id)
resp = broker.to_json()
resp["tags"] = self.mq_backend.list_tags(broker.arn)
return 200, {}, json.dumps(resp)
def list_brokers(self):
def list_brokers(self) -> TYPE_RESPONSE:
brokers = self.mq_backend.list_brokers()
return 200, {}, json.dumps(dict(brokerSummaries=[b.summary() for b in brokers]))
def create_user(self):
def create_user(self) -> TYPE_RESPONSE:
params = json.loads(self.body)
broker_id = self.path.split("/")[-3]
username = self.path.split("/")[-1]
@ -170,7 +172,7 @@ class MQResponse(BaseResponse):
self.mq_backend.create_user(broker_id, username, console_access, groups)
return 200, {}, "{}"
def update_user(self):
def update_user(self) -> TYPE_RESPONSE:
params = json.loads(self.body)
broker_id = self.path.split("/")[-3]
username = self.path.split("/")[-1]
@ -184,19 +186,19 @@ class MQResponse(BaseResponse):
)
return 200, {}, "{}"
def describe_user(self):
def describe_user(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-3]
username = self.path.split("/")[-1]
user = self.mq_backend.describe_user(broker_id, username)
return 200, {}, json.dumps(user.to_json())
def delete_user(self):
def delete_user(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-3]
username = self.path.split("/")[-1]
self.mq_backend.delete_user(broker_id, username)
return 200, {}, "{}"
def list_users(self):
def list_users(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-2]
users = self.mq_backend.list_users(broker_id=broker_id)
resp = {
@ -205,7 +207,7 @@ class MQResponse(BaseResponse):
}
return 200, {}, json.dumps(resp)
def create_configuration(self):
def create_configuration(self) -> TYPE_RESPONSE:
params = json.loads(self.body)
name = params.get("name")
engine_type = params.get("engineType")
@ -217,19 +219,19 @@ class MQResponse(BaseResponse):
)
return 200, {}, json.dumps(config.to_json())
def describe_configuration(self):
def describe_configuration(self) -> TYPE_RESPONSE:
config_id = self.path.split("/")[-1]
config = self.mq_backend.describe_configuration(config_id)
resp = config.to_json()
resp["tags"] = self.mq_backend.list_tags(config.arn)
return 200, {}, json.dumps(resp)
def list_configurations(self):
def list_configurations(self) -> TYPE_RESPONSE:
configs = self.mq_backend.list_configurations()
resp = {"configurations": [c.to_json() for c in configs]}
return 200, {}, json.dumps(resp)
def update_configuration(self):
def update_configuration(self) -> TYPE_RESPONSE:
config_id = self.path.split("/")[-1]
params = json.loads(self.body)
data = params.get("data")
@ -237,7 +239,7 @@ class MQResponse(BaseResponse):
config = self.mq_backend.update_configuration(config_id, data, description)
return 200, {}, json.dumps(config.to_json())
def get_configuration_revision(self):
def get_configuration_revision(self) -> TYPE_RESPONSE:
revision_id = self.path.split("/")[-1]
config_id = self.path.split("/")[-3]
revision = self.mq_backend.describe_configuration_revision(
@ -245,19 +247,19 @@ class MQResponse(BaseResponse):
)
return 200, {}, json.dumps(revision.to_json())
def create_tags(self):
def create_tags(self) -> TYPE_RESPONSE:
resource_arn = unquote(self.path.split("/")[-1])
tags = json.loads(self.body).get("tags", {})
self.mq_backend.create_tags(resource_arn, tags)
return 200, {}, "{}"
def delete_tags(self):
def delete_tags(self) -> TYPE_RESPONSE:
resource_arn = unquote(self.path.split("/")[-1])
tag_keys = self._get_param("tagKeys")
self.mq_backend.delete_tags(resource_arn, tag_keys)
return 200, {}, "{}"
def reboot(self, request, full_url, headers):
def reboot(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "POST":
broker_id = self.path.split("/")[-2]

View File

@ -235,7 +235,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/managedblockchain,moto/moto_api,moto/neptune,moto/opensearch,moto/rdsdata
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/neptune,moto/opensearch,moto/rdsdata
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract

View File

@ -205,6 +205,31 @@ def test_start_stop_flow_succeeds():
describe_response["Flow"]["Status"].should.equal("STANDBY")
@mock_mediaconnect
def test_unknown_flow():
client = boto3.client("mediaconnect", region_name=region)
with pytest.raises(ClientError) as exc:
client.describe_flow(FlowArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
with pytest.raises(ClientError) as exc:
client.delete_flow(FlowArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
with pytest.raises(ClientError) as exc:
client.start_flow(FlowArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
with pytest.raises(ClientError) as exc:
client.stop_flow(FlowArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
with pytest.raises(ClientError) as exc:
client.list_tags_for_resource(ResourceArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
@mock_mediaconnect
def test_tag_resource_succeeds():
client = boto3.client("mediaconnect", region_name=region)