diff --git a/moto/mediaconnect/exceptions.py b/moto/mediaconnect/exceptions.py index 6b75f85d7..6e991ab30 100644 --- a/moto/mediaconnect/exceptions.py +++ b/moto/mediaconnect/exceptions.py @@ -4,5 +4,5 @@ from moto.core.exceptions import JsonRESTError class NotFoundException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("NotFoundException", message) diff --git a/moto/mediaconnect/models.py b/moto/mediaconnect/models.py index 3d0f90729..3a5247d49 100644 --- a/moto/mediaconnect/models.py +++ b/moto/mediaconnect/models.py @@ -1,12 +1,15 @@ from collections import OrderedDict +from typing import Any, Dict, List, Optional from moto.core import BaseBackend, BackendDict, BaseModel from moto.mediaconnect.exceptions import NotFoundException from moto.moto_api._internal import mock_random as random +from moto.utilities.tagging_service import TaggingService class Flow(BaseModel): - def __init__(self, **kwargs): + def __init__(self, account_id: str, region_name: str, **kwargs: Any): + self.id = random.uuid4().hex self.availability_zone = kwargs.get("availability_zone") self.entitlements = kwargs.get("entitlements", []) self.name = kwargs.get("name") @@ -15,17 +18,19 @@ class Flow(BaseModel): self.source_failover_config = kwargs.get("source_failover_config", {}) self.sources = kwargs.get("sources", []) self.vpc_interfaces = kwargs.get("vpc_interfaces", []) - self.status = "STANDBY" # one of 'STANDBY'|'ACTIVE'|'UPDATING'|'DELETING'|'STARTING'|'STOPPING'|'ERROR' - self._previous_status = None - self.description = None - self.flow_arn = None - self.egress_ip = None + self.status: Optional[ + str + ] = "STANDBY" # one of 'STANDBY'|'ACTIVE'|'UPDATING'|'DELETING'|'STARTING'|'STOPPING'|'ERROR' + self._previous_status: Optional[str] = None + self.description = "A Moto test flow" + self.flow_arn = f"arn:aws:mediaconnect:{region_name}:{account_id}:flow:{self.id}:{self.name}" + self.egress_ip = "127.0.0.1" if self.source and not self.sources: self.sources = [ self.source, ] - def to_dict(self, include=None): + def to_dict(self, include: Optional[List[str]] = None) -> Dict[str, Any]: data = { "availabilityZone": self.availability_zone, "description": self.description, @@ -47,7 +52,7 @@ class Flow(BaseModel): return new_data return data - def resolve_transient_states(self): + def resolve_transient_states(self) -> None: if self.status in ["STARTING"]: self.status = "ACTIVE" if self.status in ["STOPPING"]: @@ -57,26 +62,18 @@ class Flow(BaseModel): self._previous_status = None -class Resource(BaseModel): - def __init__(self, **kwargs): - self.resource_arn = kwargs.get("resource_arn") - self.tags = OrderedDict() - - def to_dict(self): - data = { - "resourceArn": self.resource_arn, - "tags": self.tags, - } - return data - - class MediaConnectBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._flows = OrderedDict() - self._resources = OrderedDict() + self._flows: Dict[str, Flow] = OrderedDict() + self.tagger = TaggingService() - def _add_source_details(self, source, flow_id, ingest_ip="127.0.0.1"): + def _add_source_details( + self, + source: Optional[Dict[str, Any]], + flow_id: str, + ingest_ip: str = "127.0.0.1", + ) -> None: if source: source["sourceArn"] = ( f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source" @@ -85,7 +82,9 @@ class MediaConnectBackend(BaseBackend): if not source.get("entitlementArn"): source["ingestIp"] = ingest_ip - def _add_entitlement_details(self, entitlement, entitlement_id): + def _add_entitlement_details( + self, entitlement: Optional[Dict[str, Any]], entitlement_id: str + ) -> None: if entitlement: entitlement["entitlementArn"] = ( f"arn:aws:mediaconnect:{self.region_name}" @@ -93,15 +92,9 @@ class MediaConnectBackend(BaseBackend): f":{entitlement['name']}" ) - def _create_flow_add_details(self, flow): - flow_id = random.uuid4().hex - - flow.description = "A Moto test flow" - flow.egress_ip = "127.0.0.1" - flow.flow_arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:flow:{flow_id}:{flow.name}" - + def _create_flow_add_details(self, flow: Flow) -> None: for index, _source in enumerate(flow.sources): - self._add_source_details(_source, flow_id, f"127.0.0.{index}") + self._add_source_details(_source, flow.id, f"127.0.0.{index}") for index, output in enumerate(flow.outputs or []): if output.get("protocol") in ["srt-listener", "zixi-pull"]: @@ -119,16 +112,18 @@ class MediaConnectBackend(BaseBackend): def create_flow( self, - availability_zone, - entitlements, - name, - outputs, - source, - source_failover_config, - sources, - vpc_interfaces, - ): + availability_zone: str, + entitlements: List[Dict[str, Any]], + name: str, + outputs: List[Dict[str, Any]], + source: Dict[str, Any], + source_failover_config: Dict[str, Any], + sources: List[Dict[str, Any]], + vpc_interfaces: List[Dict[str, Any]], + ) -> Flow: flow = Flow( + account_id=self.account_id, + region_name=self.region_name, availability_zone=availability_zone, entitlements=entitlements, name=name, @@ -142,11 +137,14 @@ class MediaConnectBackend(BaseBackend): self._flows[flow.flow_arn] = flow return flow - def list_flows(self, max_results, next_token): + def list_flows(self, max_results: Optional[int]) -> List[Dict[str, Any]]: + """ + Pagination is not yet implemented + """ flows = list(self._flows.values()) if max_results is not None: flows = flows[:max_results] - response_flows = [ + return [ fl.to_dict( include=[ "availabilityZone", @@ -159,74 +157,59 @@ class MediaConnectBackend(BaseBackend): ) for fl in flows ] - return response_flows, next_token - def describe_flow(self, flow_arn=None): - messages = {} + def describe_flow(self, flow_arn: str) -> Flow: if flow_arn in self._flows: flow = self._flows[flow_arn] flow.resolve_transient_states() - else: - raise NotFoundException(message="Flow not found.") - return flow.to_dict(), messages + return flow + raise NotFoundException(message="Flow not found.") - def delete_flow(self, flow_arn): + def delete_flow(self, flow_arn: str) -> Flow: if flow_arn in self._flows: - flow = self._flows[flow_arn] - del self._flows[flow_arn] - else: - raise NotFoundException(message="Flow not found.") - return flow_arn, flow.status + return self._flows.pop(flow_arn) + raise NotFoundException(message="Flow not found.") - def start_flow(self, flow_arn): + def start_flow(self, flow_arn: str) -> Flow: if flow_arn in self._flows: flow = self._flows[flow_arn] flow.status = "STARTING" - else: - raise NotFoundException(message="Flow not found.") - return flow_arn, flow.status + return flow + raise NotFoundException(message="Flow not found.") - def stop_flow(self, flow_arn): + def stop_flow(self, flow_arn: str) -> Flow: if flow_arn in self._flows: flow = self._flows[flow_arn] flow.status = "STOPPING" - else: - raise NotFoundException(message="Flow not found.") - return flow_arn, flow.status + return flow + raise NotFoundException(message="Flow not found.") - def tag_resource(self, resource_arn, tags): - if resource_arn in self._resources: - resource = self._resources[resource_arn] - else: - resource = Resource(resource_arn=resource_arn) - resource.tags.update(tags) - self._resources[resource_arn] = resource - return None + def tag_resource(self, resource_arn: str, tags: Dict[str, Any]) -> None: + tag_list = TaggingService.convert_dict_to_tags_input(tags) + self.tagger.tag_resource(resource_arn, tag_list) - def list_tags_for_resource(self, resource_arn): - if resource_arn in self._resources: - resource = self._resources[resource_arn] - else: - raise NotFoundException(message="Resource not found.") - return resource.tags + def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]: + if self.tagger.has_tags(resource_arn): + return self.tagger.get_tag_dict_for_resource(resource_arn) + raise NotFoundException(message="Resource not found.") - def add_flow_vpc_interfaces(self, flow_arn, vpc_interfaces): + def add_flow_vpc_interfaces( + self, flow_arn: str, vpc_interfaces: List[Dict[str, Any]] + ) -> Flow: if flow_arn in self._flows: flow = self._flows[flow_arn] flow.vpc_interfaces = vpc_interfaces - else: - raise NotFoundException(message=f"flow with arn={flow_arn} not found") - return flow_arn, flow.vpc_interfaces + return flow + raise NotFoundException(message=f"flow with arn={flow_arn} not found") - def add_flow_outputs(self, flow_arn, outputs): + def add_flow_outputs(self, flow_arn: str, outputs: List[Dict[str, Any]]) -> Flow: if flow_arn in self._flows: flow = self._flows[flow_arn] flow.outputs = outputs - else: - raise NotFoundException(message=f"flow with arn={flow_arn} not found") - return flow_arn, flow.outputs + return flow + raise NotFoundException(message=f"flow with arn={flow_arn} not found") - def remove_flow_vpc_interface(self, flow_arn, vpc_interface_name): + def remove_flow_vpc_interface(self, flow_arn: str, vpc_interface_name: str) -> None: if flow_arn in self._flows: flow = self._flows[flow_arn] flow.vpc_interfaces = [ @@ -236,9 +219,8 @@ class MediaConnectBackend(BaseBackend): ] else: raise NotFoundException(message=f"flow with arn={flow_arn} not found") - return flow_arn, vpc_interface_name - def remove_flow_output(self, flow_arn, output_name): + def remove_flow_output(self, flow_arn: str, output_name: str) -> None: if flow_arn in self._flows: flow = self._flows[flow_arn] flow.outputs = [ @@ -248,28 +230,27 @@ class MediaConnectBackend(BaseBackend): ] else: raise NotFoundException(message=f"flow with arn={flow_arn} not found") - return flow_arn, output_name def update_flow_output( self, - flow_arn, - output_arn, - cidr_allow_list, - description, - destination, - encryption, - max_latency, - media_stream_output_configuration, - min_latency, - port, - protocol, - remote_id, - sender_control_port, - sender_ip_address, - smoothing_latency, - stream_id, - vpc_interface_attachment, - ): + flow_arn: str, + output_arn: str, + cidr_allow_list: List[str], + description: str, + destination: str, + encryption: Dict[str, str], + max_latency: int, + media_stream_output_configuration: List[Dict[str, Any]], + min_latency: int, + port: int, + protocol: str, + remote_id: str, + sender_control_port: int, + sender_ip_address: str, + smoothing_latency: int, + stream_id: str, + vpc_interface_attachment: Dict[str, str], + ) -> Dict[str, Any]: if flow_arn not in self._flows: raise NotFoundException(message=f"flow with arn={flow_arn} not found") flow = self._flows[flow_arn] @@ -292,10 +273,12 @@ class MediaConnectBackend(BaseBackend): output["smoothingLatency"] = smoothing_latency output["streamId"] = stream_id output["vpcInterfaceAttachment"] = vpc_interface_attachment - return flow_arn, output + return output raise NotFoundException(message=f"output with arn={output_arn} not found") - def add_flow_sources(self, flow_arn, sources): + def add_flow_sources( + self, flow_arn: str, sources: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: if flow_arn not in self._flows: raise NotFoundException(message=f"flow with arn={flow_arn} not found") flow = self._flows[flow_arn] @@ -305,32 +288,32 @@ class MediaConnectBackend(BaseBackend): arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source:{source_id}:{name}" source["sourceArn"] = arn flow.sources = sources - return flow_arn, sources + return sources def update_flow_source( self, - flow_arn, - source_arn, - decryption, - description, - entitlement_arn, - ingest_port, - max_bitrate, - max_latency, - max_sync_buffer, - media_stream_source_configurations, - min_latency, - protocol, - sender_control_port, - sender_ip_address, - stream_id, - vpc_interface_name, - whitelist_cidr, - ): + flow_arn: str, + source_arn: str, + decryption: str, + description: str, + entitlement_arn: str, + ingest_port: int, + max_bitrate: int, + max_latency: int, + max_sync_buffer: int, + media_stream_source_configurations: List[Dict[str, Any]], + min_latency: int, + protocol: str, + sender_control_port: int, + sender_ip_address: str, + stream_id: str, + vpc_interface_name: str, + whitelist_cidr: str, + ) -> Optional[Dict[str, Any]]: if flow_arn not in self._flows: raise NotFoundException(message=f"flow with arn={flow_arn} not found") flow = self._flows[flow_arn] - source = next( + source: Optional[Dict[str, Any]] = next( iter( [source for source in flow.sources if source["sourceArn"] == source_arn] ), @@ -354,13 +337,13 @@ class MediaConnectBackend(BaseBackend): source["streamId"] = stream_id source["vpcInterfaceName"] = vpc_interface_name source["whitelistCidr"] = whitelist_cidr - return flow_arn, source + return source def grant_flow_entitlements( self, - flow_arn, - entitlements, - ): + flow_arn: str, + entitlements: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: if flow_arn not in self._flows: raise NotFoundException(message=f"flow with arn={flow_arn} not found") flow = self._flows[flow_arn] @@ -371,30 +354,30 @@ class MediaConnectBackend(BaseBackend): entitlement["entitlementArn"] = arn flow.entitlements += entitlements - return flow_arn, entitlements + return entitlements - def revoke_flow_entitlement(self, flow_arn, entitlement_arn): + def revoke_flow_entitlement(self, flow_arn: str, entitlement_arn: str) -> None: if flow_arn not in self._flows: raise NotFoundException(message=f"flow with arn={flow_arn} not found") flow = self._flows[flow_arn] for entitlement in flow.entitlements: if entitlement_arn == entitlement["entitlementArn"]: flow.entitlements.remove(entitlement) - return flow_arn, entitlement_arn + return raise NotFoundException( message=f"entitlement with arn={entitlement_arn} not found" ) def update_flow_entitlement( self, - flow_arn, - entitlement_arn, - description, - encryption, - entitlement_status, - name, - subscribers, - ): + flow_arn: str, + entitlement_arn: str, + description: str, + encryption: Dict[str, str], + entitlement_status: str, + name: str, + subscribers: List[str], + ) -> Dict[str, Any]: if flow_arn not in self._flows: raise NotFoundException(message=f"flow with arn={flow_arn} not found") flow = self._flows[flow_arn] @@ -405,12 +388,10 @@ class MediaConnectBackend(BaseBackend): entitlement["entitlementStatus"] = entitlement_status entitlement["name"] = name entitlement["subscribers"] = subscribers - return flow_arn, entitlement + return entitlement raise NotFoundException( message=f"entitlement with arn={entitlement_arn} not found" ) - # add methods from here - mediaconnect_backends = BackendDict(MediaConnectBackend, "mediaconnect") diff --git a/moto/mediaconnect/responses.py b/moto/mediaconnect/responses.py index cf0f76642..29438e726 100644 --- a/moto/mediaconnect/responses.py +++ b/moto/mediaconnect/responses.py @@ -1,20 +1,20 @@ import json from moto.core.responses import BaseResponse -from .models import mediaconnect_backends +from .models import mediaconnect_backends, MediaConnectBackend from urllib.parse import unquote class MediaConnectResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="mediaconnect") @property - def mediaconnect_backend(self): + def mediaconnect_backend(self) -> MediaConnectBackend: return mediaconnect_backends[self.current_account][self.region] - def create_flow(self): + def create_flow(self) -> str: availability_zone = self._get_param("availabilityZone") entitlements = self._get_param("entitlements") name = self._get_param("name") @@ -35,85 +35,79 @@ class MediaConnectResponse(BaseResponse): ) return json.dumps(dict(flow=flow.to_dict())) - def list_flows(self): + def list_flows(self) -> str: max_results = self._get_int_param("maxResults") - next_token = self._get_param("nextToken") - flows, next_token = self.mediaconnect_backend.list_flows( - max_results=max_results, next_token=next_token - ) - return json.dumps(dict(flows=flows, nextToken=next_token)) + flows = self.mediaconnect_backend.list_flows(max_results=max_results) + return json.dumps(dict(flows=flows)) - def describe_flow(self): + def describe_flow(self) -> str: flow_arn = unquote(self._get_param("flowArn")) - flow, messages = self.mediaconnect_backend.describe_flow(flow_arn=flow_arn) - return json.dumps(dict(flow=flow, messages=messages)) + flow = self.mediaconnect_backend.describe_flow(flow_arn=flow_arn) + return json.dumps(dict(flow=flow.to_dict())) - def delete_flow(self): + def delete_flow(self) -> str: flow_arn = unquote(self._get_param("flowArn")) - flow_arn, status = self.mediaconnect_backend.delete_flow(flow_arn=flow_arn) - return json.dumps(dict(flowArn=flow_arn, status=status)) + flow = self.mediaconnect_backend.delete_flow(flow_arn=flow_arn) + return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status)) - def start_flow(self): + def start_flow(self) -> str: flow_arn = unquote(self._get_param("flowArn")) - flow_arn, status = self.mediaconnect_backend.start_flow(flow_arn=flow_arn) - return json.dumps(dict(flowArn=flow_arn, status=status)) + flow = self.mediaconnect_backend.start_flow(flow_arn=flow_arn) + return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status)) - def stop_flow(self): + def stop_flow(self) -> str: flow_arn = unquote(self._get_param("flowArn")) - flow_arn, status = self.mediaconnect_backend.stop_flow(flow_arn=flow_arn) - return json.dumps(dict(flowArn=flow_arn, status=status)) + flow = self.mediaconnect_backend.stop_flow(flow_arn=flow_arn) + return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status)) - def tag_resource(self): + def tag_resource(self) -> str: resource_arn = unquote(self._get_param("resourceArn")) tags = self._get_param("tags") self.mediaconnect_backend.tag_resource(resource_arn=resource_arn, tags=tags) return json.dumps(dict()) - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> str: resource_arn = unquote(self._get_param("resourceArn")) - tags = self.mediaconnect_backend.list_tags_for_resource( - resource_arn=resource_arn - ) + tags = self.mediaconnect_backend.list_tags_for_resource(resource_arn) return json.dumps(dict(tags=tags)) - def add_flow_vpc_interfaces(self): + def add_flow_vpc_interfaces(self) -> str: flow_arn = unquote(self._get_param("flowArn")) vpc_interfaces = self._get_param("vpcInterfaces") - flow_arn, vpc_interfaces = self.mediaconnect_backend.add_flow_vpc_interfaces( + flow = self.mediaconnect_backend.add_flow_vpc_interfaces( flow_arn=flow_arn, vpc_interfaces=vpc_interfaces ) - return json.dumps(dict(flow_arn=flow_arn, vpc_interfaces=vpc_interfaces)) + return json.dumps( + dict(flow_arn=flow.flow_arn, vpc_interfaces=flow.vpc_interfaces) + ) - def remove_flow_vpc_interface(self): + def remove_flow_vpc_interface(self) -> str: flow_arn = unquote(self._get_param("flowArn")) vpc_interface_name = unquote(self._get_param("vpcInterfaceName")) - ( - flow_arn, - vpc_interface_name, - ) = self.mediaconnect_backend.remove_flow_vpc_interface( + self.mediaconnect_backend.remove_flow_vpc_interface( flow_arn=flow_arn, vpc_interface_name=vpc_interface_name ) return json.dumps( dict(flow_arn=flow_arn, vpc_interface_name=vpc_interface_name) ) - def add_flow_outputs(self): + def add_flow_outputs(self) -> str: flow_arn = unquote(self._get_param("flowArn")) outputs = self._get_param("outputs") - flow_arn, outputs = self.mediaconnect_backend.add_flow_outputs( + flow = self.mediaconnect_backend.add_flow_outputs( flow_arn=flow_arn, outputs=outputs ) - return json.dumps(dict(flow_arn=flow_arn, outputs=outputs)) + return json.dumps(dict(flow_arn=flow.flow_arn, outputs=flow.outputs)) - def remove_flow_output(self): + def remove_flow_output(self) -> str: flow_arn = unquote(self._get_param("flowArn")) output_name = unquote(self._get_param("outputArn")) - flow_arn, output_name = self.mediaconnect_backend.remove_flow_output( + self.mediaconnect_backend.remove_flow_output( flow_arn=flow_arn, output_name=output_name ) return json.dumps(dict(flow_arn=flow_arn, output_name=output_name)) - def update_flow_output(self): + def update_flow_output(self) -> str: flow_arn = unquote(self._get_param("flowArn")) output_arn = unquote(self._get_param("outputArn")) cidr_allow_list = self._get_param("cidrAllowList") @@ -133,7 +127,7 @@ class MediaConnectResponse(BaseResponse): smoothing_latency = self._get_param("smoothingLatency") stream_id = self._get_param("streamId") vpc_interface_attachment = self._get_param("vpcInterfaceAttachment") - flow_arn, output = self.mediaconnect_backend.update_flow_output( + output = self.mediaconnect_backend.update_flow_output( flow_arn=flow_arn, output_arn=output_arn, cidr_allow_list=cidr_allow_list, @@ -154,15 +148,15 @@ class MediaConnectResponse(BaseResponse): ) return json.dumps(dict(flowArn=flow_arn, output=output)) - def add_flow_sources(self): + def add_flow_sources(self) -> str: flow_arn = unquote(self._get_param("flowArn")) sources = self._get_param("sources") - flow_arn, sources = self.mediaconnect_backend.add_flow_sources( + sources = self.mediaconnect_backend.add_flow_sources( flow_arn=flow_arn, sources=sources ) return json.dumps(dict(flow_arn=flow_arn, sources=sources)) - def update_flow_source(self): + def update_flow_source(self) -> str: flow_arn = unquote(self._get_param("flowArn")) source_arn = unquote(self._get_param("sourceArn")) description = self._get_param("description") @@ -182,7 +176,7 @@ class MediaConnectResponse(BaseResponse): stream_id = self._get_param("streamId") vpc_interface_name = self._get_param("vpcInterfaceName") whitelist_cidr = self._get_param("whitelistCidr") - flow_arn, source = self.mediaconnect_backend.update_flow_source( + source = self.mediaconnect_backend.update_flow_source( flow_arn=flow_arn, source_arn=source_arn, decryption=decryption, @@ -203,23 +197,23 @@ class MediaConnectResponse(BaseResponse): ) return json.dumps(dict(flow_arn=flow_arn, source=source)) - def grant_flow_entitlements(self): + def grant_flow_entitlements(self) -> str: flow_arn = unquote(self._get_param("flowArn")) entitlements = self._get_param("entitlements") - flow_arn, entitlements = self.mediaconnect_backend.grant_flow_entitlements( + entitlements = self.mediaconnect_backend.grant_flow_entitlements( flow_arn=flow_arn, entitlements=entitlements ) return json.dumps(dict(flow_arn=flow_arn, entitlements=entitlements)) - def revoke_flow_entitlement(self): + def revoke_flow_entitlement(self) -> str: flow_arn = unquote(self._get_param("flowArn")) entitlement_arn = unquote(self._get_param("entitlementArn")) - flow_arn, entitlement_arn = self.mediaconnect_backend.revoke_flow_entitlement( + self.mediaconnect_backend.revoke_flow_entitlement( flow_arn=flow_arn, entitlement_arn=entitlement_arn ) return json.dumps(dict(flowArn=flow_arn, entitlementArn=entitlement_arn)) - def update_flow_entitlement(self): + def update_flow_entitlement(self) -> str: flow_arn = unquote(self._get_param("flowArn")) entitlement_arn = unquote(self._get_param("entitlementArn")) description = self._get_param("description") @@ -227,7 +221,7 @@ class MediaConnectResponse(BaseResponse): entitlement_status = self._get_param("entitlementStatus") name = self._get_param("name") subscribers = self._get_param("subscribers") - flow_arn, entitlement = self.mediaconnect_backend.update_flow_entitlement( + entitlement = self.mediaconnect_backend.update_flow_entitlement( flow_arn=flow_arn, entitlement_arn=entitlement_arn, description=description, diff --git a/moto/medialive/models.py b/moto/medialive/models.py index 3213b967c..82fc1b219 100644 --- a/moto/medialive/models.py +++ b/moto/medialive/models.py @@ -1,11 +1,12 @@ from collections import OrderedDict +from typing import Any, Dict, List, Optional from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api._internal import mock_random class Input(BaseModel): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.arn = kwargs.get("arn") self.attached_channels = kwargs.get("attached_channels", []) self.destinations = kwargs.get("destinations", []) @@ -23,8 +24,8 @@ class Input(BaseModel): self.tags = kwargs.get("tags") self.input_type = kwargs.get("input_type") - def to_dict(self): - data = { + def to_dict(self) -> Dict[str, Any]: + return { "arn": self.arn, "attachedChannels": self.attached_channels, "destinations": self.destinations, @@ -41,9 +42,8 @@ class Input(BaseModel): "tags": self.tags, "type": self.input_type, } - return data - def _resolve_transient_states(self): + def _resolve_transient_states(self) -> None: # Resolve transient states before second call # (to simulate AWS taking its sweet time with these things) if self.state in ["CREATING"]: @@ -53,7 +53,7 @@ class Input(BaseModel): class Channel(BaseModel): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.arn = kwargs.get("arn") self.cdi_input_specification = kwargs.get("cdi_input_specification") self.channel_class = kwargs.get("channel_class", "STANDARD") @@ -71,7 +71,7 @@ class Channel(BaseModel): self.tags = kwargs.get("tags") self._previous_state = None - def to_dict(self, exclude=None): + def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]: data = { "arn": self.arn, "cdiInputSpecification": self.cdi_input_specification, @@ -97,7 +97,7 @@ class Channel(BaseModel): del data[key] return data - def _resolve_transient_states(self): + def _resolve_transient_states(self) -> None: # Resolve transient states before second call # (to simulate AWS taking its sweet time with these things) if self.state in ["CREATING", "STOPPING"]: @@ -112,24 +112,24 @@ class Channel(BaseModel): class MediaLiveBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._channels = OrderedDict() - self._inputs = OrderedDict() + self._channels: Dict[str, Channel] = OrderedDict() + self._inputs: Dict[str, Input] = OrderedDict() def create_channel( self, - cdi_input_specification, - channel_class, - destinations, - encoder_settings, - input_attachments, - input_specification, - log_level, - name, - role_arn, - tags, - ): + cdi_input_specification: Dict[str, Any], + channel_class: str, + destinations: List[Dict[str, Any]], + encoder_settings: Dict[str, Any], + input_attachments: List[Dict[str, Any]], + input_specification: Dict[str, str], + log_level: str, + name: str, + role_arn: str, + tags: Dict[str, str], + ) -> Channel: """ The RequestID and Reserved parameters are not yet implemented """ @@ -155,47 +155,49 @@ class MediaLiveBackend(BaseBackend): self._channels[channel_id] = channel return channel - def list_channels(self, max_results, next_token): + def list_channels(self, max_results: Optional[int]) -> List[Dict[str, Any]]: + """ + Pagination is not yet implemented + """ channels = list(self._channels.values()) if max_results is not None: channels = channels[:max_results] - response_channels = [ + return [ c.to_dict(exclude=["encoderSettings", "pipelineDetails"]) for c in channels ] - return response_channels, next_token - def describe_channel(self, channel_id): + def describe_channel(self, channel_id: str) -> Channel: channel = self._channels[channel_id] channel._resolve_transient_states() - return channel.to_dict() + return channel - def delete_channel(self, channel_id): + def delete_channel(self, channel_id: str) -> Channel: channel = self._channels[channel_id] channel.state = "DELETING" - return channel.to_dict() + return channel - def start_channel(self, channel_id): + def start_channel(self, channel_id: str) -> Channel: channel = self._channels[channel_id] channel.state = "STARTING" - return channel.to_dict() + return channel - def stop_channel(self, channel_id): + def stop_channel(self, channel_id: str) -> Channel: channel = self._channels[channel_id] channel.state = "STOPPING" - return channel.to_dict() + return channel def update_channel( self, - channel_id, - cdi_input_specification, - destinations, - encoder_settings, - input_attachments, - input_specification, - log_level, - name, - role_arn, - ): + channel_id: str, + cdi_input_specification: Dict[str, str], + destinations: List[Dict[str, Any]], + encoder_settings: Dict[str, Any], + input_attachments: List[Dict[str, Any]], + input_specification: Dict[str, str], + log_level: str, + name: str, + role_arn: str, + ) -> Channel: channel = self._channels[channel_id] channel.cdi_input_specification = cdi_input_specification channel.destinations = destinations @@ -214,16 +216,16 @@ class MediaLiveBackend(BaseBackend): def create_input( self, - destinations, - input_devices, - input_security_groups, - media_connect_flows, - name, - role_arn, - sources, - tags, - input_type, - ): + destinations: List[Dict[str, str]], + input_devices: List[Dict[str, str]], + input_security_groups: List[str], + media_connect_flows: List[Dict[str, str]], + name: str, + role_arn: str, + sources: List[Dict[str, str]], + tags: Dict[str, str], + input_type: str, + ) -> Input: """ The VPC and RequestId parameters are not yet implemented """ @@ -246,34 +248,35 @@ class MediaLiveBackend(BaseBackend): self._inputs[input_id] = a_input return a_input - def describe_input(self, input_id): + def describe_input(self, input_id: str) -> Input: a_input = self._inputs[input_id] a_input._resolve_transient_states() - return a_input.to_dict() + return a_input - def list_inputs(self, max_results, next_token): + def list_inputs(self, max_results: Optional[int]) -> List[Dict[str, Any]]: + """ + Pagination is not yet implemented + """ inputs = list(self._inputs.values()) if max_results is not None: inputs = inputs[:max_results] - response_inputs = [i.to_dict() for i in inputs] - return response_inputs, next_token + return [i.to_dict() for i in inputs] - def delete_input(self, input_id): + def delete_input(self, input_id: str) -> None: a_input = self._inputs[input_id] a_input.state = "DELETING" - return a_input.to_dict() def update_input( self, - destinations, - input_devices, - input_id, - input_security_groups, - media_connect_flows, - name, - role_arn, - sources, - ): + destinations: List[Dict[str, str]], + input_devices: List[Dict[str, str]], + input_id: str, + input_security_groups: List[str], + media_connect_flows: List[Dict[str, str]], + name: str, + role_arn: str, + sources: List[Dict[str, str]], + ) -> Input: a_input = self._inputs[input_id] a_input.destinations = destinations a_input.input_devices = input_devices diff --git a/moto/medialive/responses.py b/moto/medialive/responses.py index f3a0832fe..571f3c957 100644 --- a/moto/medialive/responses.py +++ b/moto/medialive/responses.py @@ -1,17 +1,17 @@ from moto.core.responses import BaseResponse -from .models import medialive_backends +from .models import medialive_backends, MediaLiveBackend import json class MediaLiveResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="medialive") @property - def medialive_backend(self): + def medialive_backend(self) -> MediaLiveBackend: return medialive_backends[self.current_account][self.region] - def create_channel(self): + def create_channel(self) -> str: cdi_input_specification = self._get_param("cdiInputSpecification") channel_class = self._get_param("channelClass") destinations = self._get_param("destinations") @@ -39,34 +39,33 @@ class MediaLiveResponse(BaseResponse): dict(channel=channel.to_dict(exclude=["pipelinesRunningCount"])) ) - def list_channels(self): + def list_channels(self) -> str: max_results = self._get_int_param("maxResults") - next_token = self._get_param("nextToken") - channels, next_token = self.medialive_backend.list_channels( - max_results=max_results, next_token=next_token - ) + channels = self.medialive_backend.list_channels(max_results=max_results) - return json.dumps(dict(channels=channels, nextToken=next_token)) + return json.dumps(dict(channels=channels, nextToken=None)) - def describe_channel(self): + def describe_channel(self) -> str: channel_id = self._get_param("channelId") - return json.dumps( - self.medialive_backend.describe_channel(channel_id=channel_id) - ) + channel = self.medialive_backend.describe_channel(channel_id=channel_id) + return json.dumps(channel.to_dict()) - def delete_channel(self): + def delete_channel(self) -> str: channel_id = self._get_param("channelId") - return json.dumps(self.medialive_backend.delete_channel(channel_id=channel_id)) + channel = self.medialive_backend.delete_channel(channel_id=channel_id) + return json.dumps(channel.to_dict()) - def start_channel(self): + def start_channel(self) -> str: channel_id = self._get_param("channelId") - return json.dumps(self.medialive_backend.start_channel(channel_id=channel_id)) + channel = self.medialive_backend.start_channel(channel_id=channel_id) + return json.dumps(channel.to_dict()) - def stop_channel(self): + def stop_channel(self) -> str: channel_id = self._get_param("channelId") - return json.dumps(self.medialive_backend.stop_channel(channel_id=channel_id)) + channel = self.medialive_backend.stop_channel(channel_id=channel_id) + return json.dumps(channel.to_dict()) - def update_channel(self): + def update_channel(self) -> str: channel_id = self._get_param("channelId") cdi_input_specification = self._get_param("cdiInputSpecification") destinations = self._get_param("destinations") @@ -89,7 +88,7 @@ class MediaLiveResponse(BaseResponse): ) return json.dumps(dict(channel=channel.to_dict())) - def create_input(self): + def create_input(self) -> str: destinations = self._get_param("destinations") input_devices = self._get_param("inputDevices") input_security_groups = self._get_param("inputSecurityGroups") @@ -112,25 +111,23 @@ class MediaLiveResponse(BaseResponse): ) return json.dumps({"input": a_input.to_dict()}) - def describe_input(self): + def describe_input(self) -> str: input_id = self._get_param("inputId") - return json.dumps(self.medialive_backend.describe_input(input_id=input_id)) + a_input = self.medialive_backend.describe_input(input_id=input_id) + return json.dumps(a_input.to_dict()) - def list_inputs(self): + def list_inputs(self) -> str: max_results = self._get_int_param("maxResults") - next_token = self._get_param("nextToken") - inputs, next_token = self.medialive_backend.list_inputs( - max_results=max_results, next_token=next_token - ) + inputs = self.medialive_backend.list_inputs(max_results=max_results) - return json.dumps(dict(inputs=inputs, nextToken=next_token)) + return json.dumps(dict(inputs=inputs, nextToken=None)) - def delete_input(self): + def delete_input(self) -> str: input_id = self._get_param("inputId") self.medialive_backend.delete_input(input_id=input_id) return json.dumps({}) - def update_input(self): + def update_input(self) -> str: destinations = self._get_param("destinations") input_devices = self._get_param("inputDevices") input_id = self._get_param("inputId") diff --git a/moto/mediapackage/exceptions.py b/moto/mediapackage/exceptions.py index c52de621f..e4762a333 100644 --- a/moto/mediapackage/exceptions.py +++ b/moto/mediapackage/exceptions.py @@ -7,5 +7,5 @@ class MediaPackageClientError(JsonRESTError): # AWS service exceptions are caught with the underlying botocore exception, ClientError class ClientError(MediaPackageClientError): - def __init__(self, error, message): + def __init__(self, error: str, message: str): super().__init__(error, message) diff --git a/moto/mediapackage/models.py b/moto/mediapackage/models.py index 22a1bc66a..88db5b58c 100644 --- a/moto/mediapackage/models.py +++ b/moto/mediapackage/models.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from typing import Any, Dict, List from moto.core import BaseBackend, BackendDict, BaseModel @@ -6,27 +7,23 @@ from .exceptions import ClientError class Channel(BaseModel): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.arn = kwargs.get("arn") self.channel_id = kwargs.get("channel_id") self.description = kwargs.get("description") self.tags = kwargs.get("tags") - def to_dict(self, exclude=None): - data = { + def to_dict(self) -> Dict[str, Any]: + return { "arn": self.arn, "id": self.channel_id, "description": self.description, "tags": self.tags, } - if exclude: - for key in exclude: - del data[key] - return data class OriginEndpoint(BaseModel): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.arn = kwargs.get("arn") self.authorization = kwargs.get("authorization") self.channel_id = kwargs.get("channel_id") @@ -44,8 +41,8 @@ class OriginEndpoint(BaseModel): self.url = kwargs.get("url") self.whitelist = kwargs.get("whitelist") - def to_dict(self): - data = { + def to_dict(self) -> Dict[str, Any]: + return { "arn": self.arn, "authorization": self.authorization, "channelId": self.channel_id, @@ -63,69 +60,62 @@ class OriginEndpoint(BaseModel): "url": self.url, "whitelist": self.whitelist, } - return data class MediaPackageBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._channels = OrderedDict() - self._origin_endpoints = OrderedDict() + self._channels: Dict[str, Channel] = OrderedDict() + self._origin_endpoints: Dict[str, OriginEndpoint] = OrderedDict() - def create_channel(self, description, channel_id, tags): + def create_channel( + self, description: str, channel_id: str, tags: Dict[str, str] + ) -> Channel: arn = f"arn:aws:mediapackage:channel:{channel_id}" channel = Channel( arn=arn, description=description, - egress_access_logs={}, - hls_ingest={}, channel_id=channel_id, - ingress_access_logs={}, tags=tags, ) self._channels[channel_id] = channel return channel - def list_channels(self): - channels = list(self._channels.values()) - response_channels = [c.to_dict() for c in channels] - return response_channels + def list_channels(self) -> List[Dict[str, Any]]: + return [c.to_dict() for c in self._channels.values()] - def describe_channel(self, channel_id): + def describe_channel(self, channel_id: str) -> Channel: try: - channel = self._channels[channel_id] - return channel.to_dict() + return self._channels[channel_id] except KeyError: - error = "NotFoundException" - raise ClientError(error, f"channel with id={channel_id} not found") + raise ClientError( + "NotFoundException", f"channel with id={channel_id} not found" + ) - def delete_channel(self, channel_id): - try: - channel = self._channels[channel_id] - del self._channels[channel_id] - return channel.to_dict() - - except KeyError: - error = "NotFoundException" - raise ClientError(error, f"channel with id={channel_id} not found") + def delete_channel(self, channel_id: str) -> Channel: + if channel_id in self._channels: + return self._channels.pop(channel_id) + raise ClientError( + "NotFoundException", f"channel with id={channel_id} not found" + ) def create_origin_endpoint( self, - authorization, - channel_id, - cmaf_package, - dash_package, - description, - hls_package, - endpoint_id, - manifest_name, - mss_package, - origination, - startover_window_seconds, - tags, - time_delay_seconds, - whitelist, - ): + authorization: Dict[str, str], + channel_id: str, + cmaf_package: Dict[str, Any], + dash_package: Dict[str, Any], + description: str, + hls_package: Dict[str, Any], + endpoint_id: str, + manifest_name: str, + mss_package: Dict[str, Any], + origination: str, + startover_window_seconds: int, + tags: Dict[str, str], + time_delay_seconds: int, + whitelist: List[str], + ) -> OriginEndpoint: arn = f"arn:aws:mediapackage:origin_endpoint:{endpoint_id}" url = f"https://origin-endpoint.mediapackage.{self.region_name}.amazonaws.com/{endpoint_id}" origin_endpoint = OriginEndpoint( @@ -149,43 +139,39 @@ class MediaPackageBackend(BaseBackend): self._origin_endpoints[endpoint_id] = origin_endpoint return origin_endpoint - def describe_origin_endpoint(self, endpoint_id): + def describe_origin_endpoint(self, endpoint_id: str) -> OriginEndpoint: try: - origin_endpoint = self._origin_endpoints[endpoint_id] - return origin_endpoint.to_dict() + return self._origin_endpoints[endpoint_id] except KeyError: - error = "NotFoundException" - raise ClientError(error, f"origin endpoint with id={endpoint_id} not found") + raise ClientError( + "NotFoundException", f"origin endpoint with id={endpoint_id} not found" + ) - def list_origin_endpoints(self): - origin_endpoints = list(self._origin_endpoints.values()) - response_origin_endpoints = [o.to_dict() for o in origin_endpoints] - return response_origin_endpoints + def list_origin_endpoints(self) -> List[Dict[str, Any]]: + return [o.to_dict() for o in self._origin_endpoints.values()] - def delete_origin_endpoint(self, endpoint_id): - try: - origin_endpoint = self._origin_endpoints[endpoint_id] - del self._origin_endpoints[endpoint_id] - return origin_endpoint.to_dict() - except KeyError: - error = "NotFoundException" - raise ClientError(error, f"origin endpoint with id={endpoint_id} not found") + def delete_origin_endpoint(self, endpoint_id: str) -> OriginEndpoint: + if endpoint_id in self._origin_endpoints: + return self._origin_endpoints.pop(endpoint_id) + raise ClientError( + "NotFoundException", f"origin endpoint with id={endpoint_id} not found" + ) def update_origin_endpoint( self, - authorization, - cmaf_package, - dash_package, - description, - hls_package, - endpoint_id, - manifest_name, - mss_package, - origination, - startover_window_seconds, - time_delay_seconds, - whitelist, - ): + authorization: Dict[str, str], + cmaf_package: Dict[str, Any], + dash_package: Dict[str, Any], + description: str, + hls_package: Dict[str, Any], + endpoint_id: str, + manifest_name: str, + mss_package: Dict[str, Any], + origination: str, + startover_window_seconds: int, + time_delay_seconds: int, + whitelist: List[str], + ) -> OriginEndpoint: try: origin_endpoint = self._origin_endpoints[endpoint_id] origin_endpoint.authorization = authorization @@ -200,10 +186,10 @@ class MediaPackageBackend(BaseBackend): origin_endpoint.time_delay_seconds = time_delay_seconds origin_endpoint.whitelist = whitelist return origin_endpoint - except KeyError: - error = "NotFoundException" - raise ClientError(error, f"origin endpoint with id={endpoint_id} not found") + raise ClientError( + "NotFoundException", f"origin endpoint with id={endpoint_id} not found" + ) mediapackage_backends = BackendDict(MediaPackageBackend, "mediapackage") diff --git a/moto/mediapackage/responses.py b/moto/mediapackage/responses.py index f3d85d7e7..082e40f39 100644 --- a/moto/mediapackage/responses.py +++ b/moto/mediapackage/responses.py @@ -1,17 +1,17 @@ from moto.core.responses import BaseResponse -from .models import mediapackage_backends +from .models import mediapackage_backends, MediaPackageBackend import json class MediaPackageResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="mediapackage") @property - def mediapackage_backend(self): + def mediapackage_backend(self) -> MediaPackageBackend: return mediapackage_backends[self.current_account][self.region] - def create_channel(self): + def create_channel(self) -> str: description = self._get_param("description") channel_id = self._get_param("id") tags = self._get_param("tags") @@ -20,23 +20,21 @@ class MediaPackageResponse(BaseResponse): ) return json.dumps(channel.to_dict()) - def list_channels(self): + def list_channels(self) -> str: channels = self.mediapackage_backend.list_channels() return json.dumps(dict(channels=channels)) - def describe_channel(self): + def describe_channel(self) -> str: channel_id = self._get_param("id") - return json.dumps( - self.mediapackage_backend.describe_channel(channel_id=channel_id) - ) + channel = self.mediapackage_backend.describe_channel(channel_id=channel_id) + return json.dumps(channel.to_dict()) - def delete_channel(self): + def delete_channel(self) -> str: channel_id = self._get_param("id") - return json.dumps( - self.mediapackage_backend.delete_channel(channel_id=channel_id) - ) + channel = self.mediapackage_backend.delete_channel(channel_id=channel_id) + return json.dumps(channel.to_dict()) - def create_origin_endpoint(self): + def create_origin_endpoint(self) -> str: authorization = self._get_param("authorization") channel_id = self._get_param("channelId") cmaf_package = self._get_param("cmafPackage") @@ -65,27 +63,29 @@ class MediaPackageResponse(BaseResponse): startover_window_seconds=startover_window_seconds, tags=tags, time_delay_seconds=time_delay_seconds, - whitelist=whitelist, + whitelist=whitelist, # type: ignore[arg-type] ) return json.dumps(origin_endpoint.to_dict()) - def list_origin_endpoints(self): + def list_origin_endpoints(self) -> str: origin_endpoints = self.mediapackage_backend.list_origin_endpoints() return json.dumps(dict(originEndpoints=origin_endpoints)) - def describe_origin_endpoint(self): + def describe_origin_endpoint(self) -> str: endpoint_id = self._get_param("id") - return json.dumps( - self.mediapackage_backend.describe_origin_endpoint(endpoint_id=endpoint_id) + endpoint = self.mediapackage_backend.describe_origin_endpoint( + endpoint_id=endpoint_id ) + return json.dumps(endpoint.to_dict()) - def delete_origin_endpoint(self): + def delete_origin_endpoint(self) -> str: endpoint_id = self._get_param("id") - return json.dumps( - self.mediapackage_backend.delete_origin_endpoint(endpoint_id=endpoint_id) + endpoint = self.mediapackage_backend.delete_origin_endpoint( + endpoint_id=endpoint_id ) + return json.dumps(endpoint.to_dict()) - def update_origin_endpoint(self): + def update_origin_endpoint(self) -> str: authorization = self._get_param("authorization") cmaf_package = self._get_param("cmafPackage") dash_package = self._get_param("dashPackage") @@ -110,6 +110,6 @@ class MediaPackageResponse(BaseResponse): origination=origination, startover_window_seconds=startover_window_seconds, time_delay_seconds=time_delay_seconds, - whitelist=whitelist, + whitelist=whitelist, # type: ignore[arg-type] ) return json.dumps(origin_endpoint.to_dict()) diff --git a/moto/mediastore/exceptions.py b/moto/mediastore/exceptions.py index dffb2b6e6..8967ea328 100644 --- a/moto/mediastore/exceptions.py +++ b/moto/mediastore/exceptions.py @@ -1,3 +1,4 @@ +from typing import Optional from moto.core.exceptions import JsonRESTError @@ -6,7 +7,7 @@ class MediaStoreClientError(JsonRESTError): class ContainerNotFoundException(MediaStoreClientError): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): self.code = 400 super().__init__( "ContainerNotFoundException", @@ -15,7 +16,7 @@ class ContainerNotFoundException(MediaStoreClientError): class ResourceNotFoundException(MediaStoreClientError): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): self.code = 400 super().__init__( "ResourceNotFoundException", msg or "The specified container does not exist" @@ -23,7 +24,7 @@ class ResourceNotFoundException(MediaStoreClientError): class PolicyNotFoundException(MediaStoreClientError): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): self.code = 400 super().__init__( "PolicyNotFoundException", diff --git a/moto/mediastore/models.py b/moto/mediastore/models.py index 5ef8c7b28..7dd08c7d7 100644 --- a/moto/mediastore/models.py +++ b/moto/mediastore/models.py @@ -1,5 +1,6 @@ from collections import OrderedDict from datetime import date +from typing import Any, Dict, List, Optional from moto.core import BaseBackend, BackendDict, BaseModel from .exceptions import ( @@ -10,18 +11,18 @@ from .exceptions import ( class Container(BaseModel): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.arn = kwargs.get("arn") self.name = kwargs.get("name") self.endpoint = kwargs.get("endpoint") self.status = kwargs.get("status") self.creation_time = kwargs.get("creation_time") - self.lifecycle_policy = None - self.policy = None - self.metric_policy = None + self.lifecycle_policy: Optional[str] = None + self.policy: Optional[str] = None + self.metric_policy: Optional[str] = None self.tags = kwargs.get("tags") - def to_dict(self, exclude=None): + def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]: data = { "ARN": self.arn, "Name": self.name, @@ -37,11 +38,11 @@ class Container(BaseModel): class MediaStoreBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._containers = OrderedDict() + self._containers: Dict[str, Container] = OrderedDict() - def create_container(self, name, tags): + def create_container(self, name: str, tags: Dict[str, str]) -> Container: arn = f"arn:aws:mediastore:container:{name}" container = Container( arn=arn, @@ -54,40 +55,36 @@ class MediaStoreBackend(BaseBackend): self._containers[name] = container return container - def delete_container(self, name): + def delete_container(self, name: str) -> None: if name not in self._containers: raise ContainerNotFoundException() del self._containers[name] - return {} - def describe_container(self, name): + def describe_container(self, name: str) -> Container: if name not in self._containers: raise ResourceNotFoundException() container = self._containers[name] container.status = "ACTIVE" return container - def list_containers(self): + def list_containers(self) -> List[Dict[str, Any]]: """ Pagination is not yet implemented """ - containers = list(self._containers.values()) - response_containers = [c.to_dict() for c in containers] - return response_containers, None + return [c.to_dict() for c in self._containers.values()] - def list_tags_for_resource(self, name): + def list_tags_for_resource(self, name: str) -> Optional[Dict[str, str]]: if name not in self._containers: raise ContainerNotFoundException() tags = self._containers[name].tags return tags - def put_lifecycle_policy(self, container_name, lifecycle_policy): + def put_lifecycle_policy(self, container_name: str, lifecycle_policy: str) -> None: if container_name not in self._containers: raise ResourceNotFoundException() self._containers[container_name].lifecycle_policy = lifecycle_policy - return {} - def get_lifecycle_policy(self, container_name): + def get_lifecycle_policy(self, container_name: str) -> str: if container_name not in self._containers: raise ResourceNotFoundException() lifecycle_policy = self._containers[container_name].lifecycle_policy @@ -95,13 +92,12 @@ class MediaStoreBackend(BaseBackend): raise PolicyNotFoundException() return lifecycle_policy - def put_container_policy(self, container_name, policy): + def put_container_policy(self, container_name: str, policy: str) -> None: if container_name not in self._containers: raise ResourceNotFoundException() self._containers[container_name].policy = policy - return {} - def get_container_policy(self, container_name): + def get_container_policy(self, container_name: str) -> str: if container_name not in self._containers: raise ResourceNotFoundException() policy = self._containers[container_name].policy @@ -109,13 +105,12 @@ class MediaStoreBackend(BaseBackend): raise PolicyNotFoundException() return policy - def put_metric_policy(self, container_name, metric_policy): + def put_metric_policy(self, container_name: str, metric_policy: str) -> None: if container_name not in self._containers: raise ResourceNotFoundException() self._containers[container_name].metric_policy = metric_policy - return {} - def get_metric_policy(self, container_name): + def get_metric_policy(self, container_name: str) -> str: if container_name not in self._containers: raise ResourceNotFoundException() metric_policy = self._containers[container_name].metric_policy diff --git a/moto/mediastore/responses.py b/moto/mediastore/responses.py index ecb90f779..7ba3862f1 100644 --- a/moto/mediastore/responses.py +++ b/moto/mediastore/responses.py @@ -1,73 +1,73 @@ import json from moto.core.responses import BaseResponse -from .models import mediastore_backends +from .models import mediastore_backends, MediaStoreBackend class MediaStoreResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="mediastore") @property - def mediastore_backend(self): + def mediastore_backend(self) -> MediaStoreBackend: return mediastore_backends[self.current_account][self.region] - def create_container(self): + def create_container(self) -> str: name = self._get_param("ContainerName") tags = self._get_param("Tags") container = self.mediastore_backend.create_container(name=name, tags=tags) return json.dumps(dict(Container=container.to_dict())) - def delete_container(self): + def delete_container(self) -> str: name = self._get_param("ContainerName") - result = self.mediastore_backend.delete_container(name=name) - return json.dumps(result) + self.mediastore_backend.delete_container(name=name) + return "{}" - def describe_container(self): + def describe_container(self) -> str: name = self._get_param("ContainerName") container = self.mediastore_backend.describe_container(name=name) return json.dumps(dict(Container=container.to_dict())) - def list_containers(self): - containers, next_token = self.mediastore_backend.list_containers() - return json.dumps(dict(dict(Containers=containers), NextToken=next_token)) + def list_containers(self) -> str: + containers = self.mediastore_backend.list_containers() + return json.dumps(dict(dict(Containers=containers), NextToken=None)) - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> str: name = self._get_param("Resource") tags = self.mediastore_backend.list_tags_for_resource(name) return json.dumps(dict(Tags=tags)) - def put_lifecycle_policy(self): + def put_lifecycle_policy(self) -> str: container_name = self._get_param("ContainerName") lifecycle_policy = self._get_param("LifecyclePolicy") - policy = self.mediastore_backend.put_lifecycle_policy( + self.mediastore_backend.put_lifecycle_policy( container_name=container_name, lifecycle_policy=lifecycle_policy ) - return json.dumps(policy) + return "{}" - def get_lifecycle_policy(self): + def get_lifecycle_policy(self) -> str: container_name = self._get_param("ContainerName") lifecycle_policy = self.mediastore_backend.get_lifecycle_policy( container_name=container_name ) return json.dumps(dict(LifecyclePolicy=lifecycle_policy)) - def put_container_policy(self): + def put_container_policy(self) -> str: container_name = self._get_param("ContainerName") policy = self._get_param("Policy") - container_policy = self.mediastore_backend.put_container_policy( + self.mediastore_backend.put_container_policy( container_name=container_name, policy=policy ) - return json.dumps(container_policy) + return "{}" - def get_container_policy(self): + def get_container_policy(self) -> str: container_name = self._get_param("ContainerName") policy = self.mediastore_backend.get_container_policy( container_name=container_name ) return json.dumps(dict(Policy=policy)) - def put_metric_policy(self): + def put_metric_policy(self) -> str: container_name = self._get_param("ContainerName") metric_policy = self._get_param("MetricPolicy") self.mediastore_backend.put_metric_policy( @@ -75,12 +75,9 @@ class MediaStoreResponse(BaseResponse): ) return json.dumps(metric_policy) - def get_metric_policy(self): + def get_metric_policy(self) -> str: container_name = self._get_param("ContainerName") metric_policy = self.mediastore_backend.get_metric_policy( container_name=container_name ) return json.dumps(dict(MetricPolicy=metric_policy)) - - -# add templates from here diff --git a/moto/mediastoredata/exceptions.py b/moto/mediastoredata/exceptions.py index e1c3b9674..433753cc5 100644 --- a/moto/mediastoredata/exceptions.py +++ b/moto/mediastoredata/exceptions.py @@ -7,5 +7,5 @@ class MediaStoreDataClientError(JsonRESTError): # AWS service exceptions are caught with the underlying botocore exception, ClientError class ClientError(MediaStoreDataClientError): - def __init__(self, error, message): + def __init__(self, error: str, message: str): super().__init__(error, message) diff --git a/moto/mediastoredata/models.py b/moto/mediastoredata/models.py index 7185ea9d2..f04ec3a1b 100644 --- a/moto/mediastoredata/models.py +++ b/moto/mediastoredata/models.py @@ -1,20 +1,23 @@ import hashlib from collections import OrderedDict +from typing import Any, Dict, List from moto.core import BaseBackend, BackendDict, BaseModel from .exceptions import ClientError class Object(BaseModel): - def __init__(self, path, body, etag, storage_class="TEMPORAL"): + def __init__( + self, path: str, body: str, etag: str, storage_class: str = "TEMPORAL" + ): self.path = path self.body = body self.content_sha256 = hashlib.sha256(body.encode("utf-8")).hexdigest() self.etag = etag self.storage_class = storage_class - def to_dict(self): - data = { + def to_dict(self) -> Dict[str, Any]: + return { "ETag": self.etag, "Name": self.path, "Type": "FILE", @@ -24,15 +27,15 @@ class Object(BaseModel): "ContentSHA256": self.content_sha256, } - return data - class MediaStoreDataBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._objects = OrderedDict() + self._objects: Dict[str, Object] = OrderedDict() - def put_object(self, body, path, storage_class="TEMPORAL"): + def put_object( + self, body: str, path: str, storage_class: str = "TEMPORAL" + ) -> Object: """ The following parameters are not yet implemented: ContentType, CacheControl, UploadAvailability """ @@ -42,30 +45,29 @@ class MediaStoreDataBackend(BaseBackend): self._objects[path] = new_object return new_object - def delete_object(self, path): + def delete_object(self, path: str) -> None: if path not in self._objects: - error = "ObjectNotFoundException" - raise ClientError(error, f"Object with id={path} not found") + raise ClientError( + "ObjectNotFoundException", f"Object with id={path} not found" + ) del self._objects[path] - return {} - def get_object(self, path): + def get_object(self, path: str) -> Object: """ The Range-parameter is not yet supported. """ objects_found = [item for item in self._objects.values() if item.path == path] if len(objects_found) == 0: - error = "ObjectNotFoundException" - raise ClientError(error, f"Object with id={path} not found") + raise ClientError( + "ObjectNotFoundException", f"Object with id={path} not found" + ) return objects_found[0] - def list_items(self): + def list_items(self) -> List[Dict[str, Any]]: """ The Path- and MaxResults-parameters are not yet supported. """ - items = self._objects.values() - response_items = [c.to_dict() for c in items] - return response_items + return [c.to_dict() for c in self._objects.values()] mediastoredata_backends = BackendDict(MediaStoreDataBackend, "mediastore-data") diff --git a/moto/mediastoredata/responses.py b/moto/mediastoredata/responses.py index 8e3251a17..48503852e 100644 --- a/moto/mediastoredata/responses.py +++ b/moto/mediastoredata/responses.py @@ -1,35 +1,36 @@ import json +from typing import Dict, Tuple from moto.core.responses import BaseResponse -from .models import mediastoredata_backends +from .models import mediastoredata_backends, MediaStoreDataBackend class MediaStoreDataResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="mediastore-data") @property - def mediastoredata_backend(self): + def mediastoredata_backend(self) -> MediaStoreDataBackend: return mediastoredata_backends[self.current_account][self.region] - def get_object(self): + def get_object(self) -> Tuple[str, Dict[str, str]]: path = self._get_param("Path") result = self.mediastoredata_backend.get_object(path=path) headers = {"Path": result.path} return result.body, headers - def put_object(self): + def put_object(self) -> str: body = self.body path = self._get_param("Path") new_object = self.mediastoredata_backend.put_object(body, path) object_dict = new_object.to_dict() return json.dumps(object_dict) - def delete_object(self): + def delete_object(self) -> str: item_id = self._get_param("Path") - result = self.mediastoredata_backend.delete_object(path=item_id) - return json.dumps(result) + self.mediastoredata_backend.delete_object(path=item_id) + return "{}" - def list_items(self): + def list_items(self) -> str: items = self.mediastoredata_backend.list_items() return json.dumps(dict(Items=items)) diff --git a/moto/meteringmarketplace/exceptions.py b/moto/meteringmarketplace/exceptions.py index 188e01549..e69de29bb 100644 --- a/moto/meteringmarketplace/exceptions.py +++ b/moto/meteringmarketplace/exceptions.py @@ -1,38 +0,0 @@ -from moto.core.exceptions import JsonRESTError - - -class DisabledApiException(JsonRESTError): - def __init__(self, message): - super().__init__(error_type="DisabledApiException", message=message) - - -class InternalServiceErrorException(JsonRESTError): - def __init__(self, message): - super().__init__(error_type="InternalServiceErrorException", message=message) - - -class InvalidCustomerIdentifierException(JsonRESTError): - def __init__(self, message): - super().__init__( - error_type="InvalidCustomerIdentifierException", message=message - ) - - -class InvalidProductCodeException(JsonRESTError): - def __init__(self, message): - super().__init__(error_type="InvalidProductCodeException", message=message) - - -class InvalidUsageDimensionException(JsonRESTError): - def __init__(self, message): - super().__init__(error_type="InvalidUsageDimensionException", message=message) - - -class ThrottlingException(JsonRESTError): - def __init__(self, message): - super().__init__(error_type="ThrottlingException", message=message) - - -class TimestampOutOfBoundsException(JsonRESTError): - def __init__(self, message): - super().__init__(error_type="TimestampOutOfBoundsException", message=message) diff --git a/moto/meteringmarketplace/models.py b/moto/meteringmarketplace/models.py index 5d3933b84..e50cb463a 100644 --- a/moto/meteringmarketplace/models.py +++ b/moto/meteringmarketplace/models.py @@ -1,10 +1,17 @@ import collections +from typing import Any, Deque, Dict, List from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api._internal import mock_random -class UsageRecord(BaseModel, dict): - def __init__(self, timestamp, customer_identifier, dimension, quantity=0): +class UsageRecord(BaseModel, Dict[str, Any]): # type: ignore[misc] + def __init__( + self, + timestamp: str, + customer_identifier: str, + dimension: str, + quantity: int = 0, + ): super().__init__() self.timestamp = timestamp self.customer_identifier = customer_identifier @@ -12,54 +19,45 @@ class UsageRecord(BaseModel, dict): self.quantity = quantity self.metering_record_id = mock_random.uuid4().hex - @classmethod - def from_data(cls, data): - cls( - timestamp=data.get("Timestamp"), - customer_identifier=data.get("CustomerIdentifier"), - dimension=data.get("Dimension"), - quantity=data.get("Quantity", 0), - ) - @property - def timestamp(self): + def timestamp(self) -> str: return self["Timestamp"] @timestamp.setter - def timestamp(self, value): + def timestamp(self, value: str) -> None: self["Timestamp"] = value @property - def customer_identifier(self): + def customer_identifier(self) -> str: return self["CustomerIdentifier"] @customer_identifier.setter - def customer_identifier(self, value): + def customer_identifier(self, value: str) -> None: self["CustomerIdentifier"] = value @property - def dimension(self): + def dimension(self) -> str: return self["Dimension"] @dimension.setter - def dimension(self, value): + def dimension(self, value: str) -> None: self["Dimension"] = value @property - def quantity(self): + def quantity(self) -> int: return self["Quantity"] @quantity.setter - def quantity(self, value): + def quantity(self, value: int) -> None: self["Quantity"] = value -class Result(BaseModel, dict): +class Result(BaseModel, Dict[str, Any]): # type: ignore[misc] SUCCESS = "Success" CUSTOMER_NOT_SUBSCRIBED = "CustomerNotSubscribed" DUPLICATE_RECORD = "DuplicateRecord" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.usage_record = UsageRecord( timestamp=kwargs["Timestamp"], customer_identifier=kwargs["CustomerIdentifier"], @@ -70,28 +68,26 @@ class Result(BaseModel, dict): self["MeteringRecordId"] = self.usage_record.metering_record_id @property - def metering_record_id(self): + def metering_record_id(self) -> str: return self["MeteringRecordId"] @property - def status(self): + def status(self) -> str: return self["Status"] @status.setter - def status(self, value): + def status(self, value: str) -> None: self["Status"] = value @property - def usage_record(self): + def usage_record(self) -> UsageRecord: return self["UsageRecord"] @usage_record.setter - def usage_record(self, value): - if not isinstance(value, UsageRecord): - value = UsageRecord.from_data(value) + def usage_record(self, value: UsageRecord) -> None: self["UsageRecord"] = value - def is_duplicate(self, other): + def is_duplicate(self, other: Any) -> bool: """ DuplicateRecord - Indicates that the UsageRecord was invalid and not honored. A previously metered UsageRecord had the same customer, dimension, and time, @@ -107,23 +103,29 @@ class Result(BaseModel, dict): ) -class CustomerDeque(collections.deque): - def is_subscribed(self, customer): +class CustomerDeque(Deque[str]): + def is_subscribed(self, customer: str) -> bool: return customer in self -class ResultDeque(collections.deque): - def is_duplicate(self, result): +class ResultDeque(Deque[Result]): + def is_duplicate(self, result: Result) -> bool: return any(record.is_duplicate(result) for record in self) class MeteringMarketplaceBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.customers_by_product = collections.defaultdict(CustomerDeque) - self.records_by_product = collections.defaultdict(ResultDeque) + self.customers_by_product: Dict[str, CustomerDeque] = collections.defaultdict( + CustomerDeque + ) + self.records_by_product: Dict[str, ResultDeque] = collections.defaultdict( + ResultDeque + ) - def batch_meter_usage(self, product_code, usage_records): + def batch_meter_usage( + self, product_code: str, usage_records: List[Dict[str, Any]] + ) -> List[Result]: results = [] for usage in usage_records: result = Result(**usage) diff --git a/moto/meteringmarketplace/responses.py b/moto/meteringmarketplace/responses.py index d3addb966..9d14cbab4 100644 --- a/moto/meteringmarketplace/responses.py +++ b/moto/meteringmarketplace/responses.py @@ -9,8 +9,7 @@ class MarketplaceMeteringResponse(BaseResponse): def backend(self) -> MeteringMarketplaceBackend: return meteringmarketplace_backends[self.current_account][self.region] - def batch_meter_usage(self): - results = [] + def batch_meter_usage(self) -> str: usage_records = json.loads(self.body)["UsageRecords"] product_code = json.loads(self.body)["ProductCode"] results = self.backend.batch_meter_usage(product_code, usage_records) diff --git a/moto/moto_server/threaded_moto_server.py b/moto/moto_server/threaded_moto_server.py index 70333e282..602f06cec 100644 --- a/moto/moto_server/threaded_moto_server.py +++ b/moto/moto_server/threaded_moto_server.py @@ -1,29 +1,32 @@ import time from threading import Thread -from werkzeug.serving import make_server +from typing import Optional +from werkzeug.serving import make_server, BaseWSGIServer from .werkzeug_app import DomainDispatcherApplication, create_backend_app class ThreadedMotoServer: - def __init__(self, ip_address="0.0.0.0", port=5000, verbose=True): + def __init__( + self, ip_address: str = "0.0.0.0", port: int = 5000, verbose: bool = True + ): self._port = port - self._thread = None + self._thread: Optional[Thread] = None self._ip_address = ip_address - self._server = None + self._server: Optional[BaseWSGIServer] = None self._server_ready = False self._verbose = verbose - def _server_entry(self): + def _server_entry(self) -> None: app = DomainDispatcherApplication(create_backend_app) self._server = make_server(self._ip_address, self._port, app, True) self._server_ready = True self._server.serve_forever() - def start(self): + def start(self) -> None: if self._verbose: print( # noqa f"Starting a new Thread with MotoServer running on {self._ip_address}:{self._port}..." @@ -33,9 +36,9 @@ class ThreadedMotoServer: while not self._server_ready: time.sleep(0.1) - def stop(self): + def stop(self) -> None: self._server_ready = False if self._server: self._server.shutdown() - self._thread.join() + self._thread.join() # type: ignore[union-attr] diff --git a/moto/moto_server/utilities.py b/moto/moto_server/utilities.py index 3753b502b..bf4268b1c 100644 --- a/moto/moto_server/utilities.py +++ b/moto/moto_server/utilities.py @@ -1,6 +1,6 @@ import json from flask.testing import FlaskClient - +from typing import Any, Dict from urllib.parse import urlencode from werkzeug.routing import BaseConverter @@ -10,13 +10,13 @@ class RegexConverter(BaseConverter): part_isolating = False - def __init__(self, url_map, *items): + def __init__(self, url_map: Any, *items: Any): super().__init__(url_map) self.regex = items[0] class AWSTestHelper(FlaskClient): - def action_data(self, action_name, **kwargs): + def action_data(self, action_name: str, **kwargs: Any) -> str: """ Method calls resource with action_name and returns data of response. """ @@ -24,11 +24,11 @@ class AWSTestHelper(FlaskClient): opts.update(kwargs) res = self.get( f"/?{urlencode(opts)}", - headers={"Host": f"{self.application.service}.us-east-1.amazonaws.com"}, + headers={"Host": f"{self.application.service}.us-east-1.amazonaws.com"}, # type: ignore[attr-defined] ) return res.data.decode("utf-8") - def action_json(self, action_name, **kwargs): + def action_json(self, action_name: str, **kwargs: Any) -> Dict[str, Any]: """ Method calls resource with action_name and returns object obtained via deserialization of output. diff --git a/moto/moto_server/werkzeug_app.py b/moto/moto_server/werkzeug_app.py index ab07f3bc7..14f855174 100644 --- a/moto/moto_server/werkzeug_app.py +++ b/moto/moto_server/werkzeug_app.py @@ -2,6 +2,7 @@ import io import os import os.path from threading import Lock +from typing import Any, Callable, Dict, Optional, Tuple try: from flask import Flask @@ -49,20 +50,22 @@ SIGNING_ALIASES = { SERVICE_BY_VERSION = {"2009-04-15": "sdb"} -class DomainDispatcherApplication(object): +class DomainDispatcherApplication: """ Dispatch requests to different applications based on the "Host:" header value. We'll match the host header value with the url_bases of each backend. """ - def __init__(self, create_app, service=None): + def __init__( + self, create_app: Callable[[str], Flask], service: Optional[str] = None + ): self.create_app = create_app self.lock = Lock() - self.app_instances = {} + self.app_instances: Dict[str, Flask] = {} self.service = service self.backend_url_patterns = backend_index.backend_url_patterns - def get_backend_for_host(self, host): + def get_backend_for_host(self, host: str) -> Any: if host == "moto_api": return host @@ -83,7 +86,9 @@ class DomainDispatcherApplication(object): "Remember to add the URL to urls.py, and run scripts/update_backend_index.py to index it." ) - def infer_service_region_host(self, body, environ): + def infer_service_region_host( + self, body: Optional[str], environ: Dict[str, Any] + ) -> str: auth = environ.get("HTTP_AUTHORIZATION") target = environ.get("HTTP_X_AMZ_TARGET") service = None @@ -111,7 +116,7 @@ class DomainDispatcherApplication(object): service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION) elif action and action in UNSIGNED_ACTIONS: # See if we can match the Action to a known service - service, region = UNSIGNED_ACTIONS.get(action) + service, region = UNSIGNED_ACTIONS[action] if not service: service, region = self.get_service_from_body(body, environ) if not service: @@ -153,7 +158,7 @@ class DomainDispatcherApplication(object): return host - def get_application(self, environ): + def get_application(self, environ: Dict[str, Any]) -> Flask: path_info = environ.get("PATH_INFO", "") # The URL path might contain non-ASCII text, for instance unicode S3 bucket names @@ -181,7 +186,7 @@ class DomainDispatcherApplication(object): self.app_instances[backend] = app return app - def _get_body(self, environ): + def _get_body(self, environ: Dict[str, Any]) -> Optional[str]: body = None try: # AWS requests use querystrings as the body (Action=x&Data=y&...) @@ -190,7 +195,7 @@ class DomainDispatcherApplication(object): ) request_body_size = int(environ["CONTENT_LENGTH"]) if simple_form and request_body_size: - body = environ["wsgi.input"].read(request_body_size).decode("utf-8") + body = environ["wsgi.input"].read(request_body_size).decode("utf-8") # type: ignore except (KeyError, ValueError): pass finally: @@ -199,7 +204,9 @@ class DomainDispatcherApplication(object): environ["wsgi.input"] = io.StringIO(body) return body - def get_service_from_body(self, body, environ): + def get_service_from_body( + self, body: Optional[str], environ: Dict[str, Any] + ) -> Tuple[Optional[str], Optional[str]]: # Some services have the SDK Version in the body # If the version is unique, we can derive the service from it version = self.get_version_from_body(body) @@ -209,22 +216,24 @@ class DomainDispatcherApplication(object): return SERVICE_BY_VERSION[version], region return None, None - def get_version_from_body(self, body): + def get_version_from_body(self, body: Optional[str]) -> Optional[str]: try: - body_dict = dict(x.split("=") for x in body.split("&")) + body_dict = dict(x.split("=") for x in body.split("&")) # type: ignore return body_dict["Version"] except (AttributeError, KeyError, ValueError): return None - def get_action_from_body(self, body): + def get_action_from_body(self, body: Optional[str]) -> Optional[str]: try: # AWS requests use querystrings as the body (Action=x&Data=y&...) - body_dict = dict(x.split("=") for x in body.split("&")) + body_dict = dict(x.split("=") for x in body.split("&")) # type: ignore return body_dict["Action"] except (AttributeError, KeyError, ValueError): return None - def get_service_from_path(self, environ): + def get_service_from_path( + self, environ: Dict[str, Any] + ) -> Tuple[Optional[str], Optional[str]]: # Moto sometimes needs to send a HTTP request to itself # In which case it will send a request to 'http://localhost/service_region/whatever' try: @@ -234,12 +243,12 @@ class DomainDispatcherApplication(object): except (AttributeError, KeyError, ValueError): return None, None - def __call__(self, environ, start_response): + def __call__(self, environ: Dict[str, Any], start_response: Any) -> Any: backend_app = self.get_application(environ) return backend_app(environ, start_response) -def create_backend_app(service): +def create_backend_app(service: str) -> Flask: from werkzeug.routing import Map current_file = os.path.abspath(__file__) @@ -249,7 +258,7 @@ def create_backend_app(service): # Create the backend_app backend_app = Flask("moto", template_folder=template_dir) backend_app.debug = True - backend_app.service = service + backend_app.service = service # type: ignore[attr-defined] CORS(backend_app) # Reset view functions to reset the app diff --git a/moto/mq/exceptions.py b/moto/mq/exceptions.py index 0276178cd..d22dc58fd 100644 --- a/moto/mq/exceptions.py +++ b/moto/mq/exceptions.py @@ -1,4 +1,5 @@ import json +from typing import Any from moto.core.exceptions import JsonRESTError @@ -7,11 +8,13 @@ class MQError(JsonRESTError): class UnknownBroker(MQError): - def __init__(self, broker_id): + def __init__(self, broker_id: str): super().__init__("NotFoundException", "Can't find requested broker") self.broker_id = broker_id - def get_body(self, *args, **kwargs): # pylint: disable=unused-argument + def get_body( + self, *args: Any, **kwargs: Any + ) -> str: # pylint: disable=unused-argument body = { "errorAttribute": "broker-id", "message": f"Can't find requested broker [{self.broker_id}]. Make sure your broker exists.", @@ -20,11 +23,13 @@ class UnknownBroker(MQError): class UnknownConfiguration(MQError): - def __init__(self, config_id): + def __init__(self, config_id: str): super().__init__("NotFoundException", "Can't find requested configuration") self.config_id = config_id - def get_body(self, *args, **kwargs): # pylint: disable=unused-argument + def get_body( + self, *args: Any, **kwargs: Any + ) -> str: # pylint: disable=unused-argument body = { "errorAttribute": "configuration_id", "message": f"Can't find requested configuration [{self.config_id}]. Make sure your configuration exists.", @@ -33,11 +38,13 @@ class UnknownConfiguration(MQError): class UnknownUser(MQError): - def __init__(self, username): + def __init__(self, username: str): super().__init__("NotFoundException", "Can't find requested user") self.username = username - def get_body(self, *args, **kwargs): # pylint: disable=unused-argument + def get_body( + self, *args: Any, **kwargs: Any + ) -> str: # pylint: disable=unused-argument body = { "errorAttribute": "username", "message": f"Can't find requested user [{self.username}]. Make sure your user exists.", @@ -46,11 +53,13 @@ class UnknownUser(MQError): class UnsupportedEngineType(MQError): - def __init__(self, engine_type): + def __init__(self, engine_type: str): super().__init__("BadRequestException", "") self.engine_type = engine_type - def get_body(self, *args, **kwargs): # pylint: disable=unused-argument + def get_body( + self, *args: Any, **kwargs: Any + ) -> str: # pylint: disable=unused-argument body = { "errorAttribute": "engineType", "message": f"Broker engine type [{self.engine_type}] does not support configuration.", @@ -59,11 +68,13 @@ class UnsupportedEngineType(MQError): class UnknownEngineType(MQError): - def __init__(self, engine_type): + def __init__(self, engine_type: str): super().__init__("BadRequestException", "") self.engine_type = engine_type - def get_body(self, *args, **kwargs): # pylint: disable=unused-argument + def get_body( + self, *args: Any, **kwargs: Any + ) -> str: # pylint: disable=unused-argument body = { "errorAttribute": "engineType", "message": f"Broker engine type [{self.engine_type}] is invalid. Valid values are: [ACTIVEMQ]", diff --git a/moto/mq/models.py b/moto/mq/models.py index 29491d9cf..28d7113f5 100644 --- a/moto/mq/models.py +++ b/moto/mq/models.py @@ -1,5 +1,6 @@ import base64 import xmltodict +from typing import Any, Dict, List, Iterable, Optional, Tuple from moto.core import BaseBackend, BackendDict, BaseModel from moto.core.utils import unix_time @@ -17,7 +18,13 @@ from .exceptions import ( class ConfigurationRevision(BaseModel): - def __init__(self, configuration_id, revision_id, description, data=None): + def __init__( + self, + configuration_id: str, + revision_id: str, + description: str, + data: Optional[str] = None, + ): self.configuration_id = configuration_id self.created = unix_time() self.description = description @@ -31,7 +38,7 @@ class ConfigurationRevision(BaseModel): else: self.data = data - def has_ldap_auth(self): + def has_ldap_auth(self) -> bool: try: xml = base64.b64decode(self.data) dct = xmltodict.parse(xml, dict_constructor=dict) @@ -45,7 +52,7 @@ class ConfigurationRevision(BaseModel): # If anything fails, lets assume it's not LDAP return False - def to_json(self, full=True): + def to_json(self, full: bool = True) -> Dict[str, Any]: resp = { "created": self.created, "description": self.description, @@ -58,7 +65,14 @@ class ConfigurationRevision(BaseModel): class Configuration(BaseModel): - def __init__(self, account_id, region, name, engine_type, engine_version): + def __init__( + self, + account_id: str, + region: str, + name: str, + engine_type: str, + engine_version: str, + ): self.id = f"c-{mock_random.get_random_hex(6)}" self.arn = f"arn:aws:mq:{region}:{account_id}:configuration:{self.id}" self.created = unix_time() @@ -67,7 +81,7 @@ class Configuration(BaseModel): self.engine_type = engine_type self.engine_version = engine_version - self.revisions = dict() + self.revisions: Dict[str, ConfigurationRevision] = dict() default_desc = ( f"Auto-generated default for {self.name} on {engine_type} {engine_version}" ) @@ -80,7 +94,7 @@ class Configuration(BaseModel): "ldap" if latest_revision.has_ldap_auth() else "simple" ) - def update(self, data, description): + def update(self, data: str, description: str) -> None: max_revision_id, _ = sorted(self.revisions.items())[-1] next_revision_id = str(int(max_revision_id) + 1) latest_revision = ConfigurationRevision( @@ -95,10 +109,10 @@ class Configuration(BaseModel): "ldap" if latest_revision.has_ldap_auth() else "simple" ) - def get_revision(self, revision_id): + def get_revision(self, revision_id: str) -> ConfigurationRevision: return self.revisions[revision_id] - def to_json(self): + def to_json(self) -> Dict[str, Any]: _, latest_revision = sorted(self.revisions.items())[-1] return { "arn": self.arn, @@ -113,22 +127,30 @@ class Configuration(BaseModel): class User(BaseModel): - def __init__(self, broker_id, username, console_access=None, groups=None): + def __init__( + self, + broker_id: str, + username: str, + console_access: Optional[bool] = None, + groups: Optional[List[str]] = None, + ): self.broker_id = broker_id self.username = username self.console_access = console_access or False self.groups = groups or [] - def update(self, console_access, groups): + def update( + self, console_access: Optional[bool], groups: Optional[List[str]] + ) -> None: if console_access is not None: self.console_access = console_access if groups: self.groups = groups - def summary(self): + def summary(self) -> Dict[str, str]: return {"username": self.username} - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "brokerId": self.broker_id, "username": self.username, @@ -140,25 +162,25 @@ class User(BaseModel): class Broker(BaseModel): def __init__( self, - name, - account_id, - region, - authentication_strategy, - auto_minor_version_upgrade, - configuration, - deployment_mode, - encryption_options, - engine_type, - engine_version, - host_instance_type, - ldap_server_metadata, - logs, - maintenance_window_start_time, - publicly_accessible, - security_groups, - storage_type, - subnet_ids, - users, + name: str, + account_id: str, + region: str, + authentication_strategy: str, + auto_minor_version_upgrade: bool, + configuration: Dict[str, Any], + deployment_mode: str, + encryption_options: Dict[str, Any], + engine_type: str, + engine_version: str, + host_instance_type: str, + ldap_server_metadata: Dict[str, Any], + logs: Dict[str, bool], + maintenance_window_start_time: Dict[str, str], + publicly_accessible: bool, + security_groups: List[str], + storage_type: str, + subnet_ids: List[str], + users: List[Dict[str, Any]], ): self.name = name self.id = mock_random.get_random_hex(6) @@ -206,7 +228,7 @@ class Broker(BaseModel): else: self.subnet_ids = ["default-subnet"] - self.users = dict() + self.users: Dict[str, User] = dict() for user in users: self.create_user( username=user["username"], @@ -215,13 +237,13 @@ class Broker(BaseModel): ) if self.engine_type.upper() == "RABBITMQ": - self.configurations = None + self.configurations: Optional[Dict[str, Any]] = None else: current_config = configuration or { "id": f"c-{mock_random.get_random_hex(6)}", "revision": 1, } - self.configurations = { + self.configurations = { # type: ignore[no-redef] "current": current_config, "history": [], } @@ -256,23 +278,23 @@ class Broker(BaseModel): def update( self, - authentication_strategy, - auto_minor_version_upgrade, - configuration, - engine_version, - host_instance_type, - ldap_server_metadata, - logs, - maintenance_window_start_time, - security_groups, - ): + authentication_strategy: Optional[str], + auto_minor_version_upgrade: Optional[bool], + configuration: Optional[Dict[str, Any]], + engine_version: Optional[str], + host_instance_type: Optional[str], + ldap_server_metadata: Optional[Dict[str, Any]], + logs: Optional[Dict[str, bool]], + maintenance_window_start_time: Optional[Dict[str, str]], + security_groups: Optional[List[str]], + ) -> None: if authentication_strategy: self.authentication_strategy = authentication_strategy if auto_minor_version_upgrade is not None: self.auto_minor_version_upgrade = auto_minor_version_upgrade if configuration: - self.configurations["history"].append(self.configurations["current"]) - self.configurations["current"] = configuration + self.configurations["history"].append(self.configurations["current"]) # type: ignore[index] + self.configurations["current"] = configuration # type: ignore[index] if engine_version: self.engine_version = engine_version if host_instance_type: @@ -286,29 +308,33 @@ class Broker(BaseModel): if security_groups: self.security_groups = security_groups - def reboot(self): + def reboot(self) -> None: pass - def create_user(self, username, console_access, groups): + def create_user( + self, username: str, console_access: bool, groups: List[str] + ) -> None: user = User(self.id, username, console_access, groups) self.users[username] = user - def update_user(self, username, console_access, groups): + def update_user( + self, username: str, console_access: bool, groups: List[str] + ) -> None: user = self.get_user(username) user.update(console_access, groups) - def get_user(self, username): + def get_user(self, username: str) -> User: if username not in self.users: raise UnknownUser(username) return self.users[username] - def delete_user(self, username): + def delete_user(self, username: str) -> None: self.users.pop(username, None) - def list_users(self): + def list_users(self) -> Iterable[User]: return self.users.values() - def summary(self): + def summary(self) -> Dict[str, Any]: return { "brokerArn": self.arn, "brokerId": self.id, @@ -320,7 +346,7 @@ class Broker(BaseModel): "hostInstanceType": self.host_instance_type, } - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "brokerId": self.id, "brokerArn": self.arn, @@ -352,33 +378,33 @@ class MQBackend(BaseBackend): No EC2 integration exists yet - subnet ID's and security group values are not validated. Default values may not exist. """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.brokers = dict() - self.configs = dict() + self.brokers: Dict[str, Broker] = dict() + self.configs: Dict[str, Configuration] = dict() self.tagger = TaggingService() def create_broker( self, - authentication_strategy, - auto_minor_version_upgrade, - broker_name, - configuration, - deployment_mode, - encryption_options, - engine_type, - engine_version, - host_instance_type, - ldap_server_metadata, - logs, - maintenance_window_start_time, - publicly_accessible, - security_groups, - storage_type, - subnet_ids, - tags, - users, - ): + authentication_strategy: str, + auto_minor_version_upgrade: bool, + broker_name: str, + configuration: Dict[str, Any], + deployment_mode: str, + encryption_options: Dict[str, Any], + engine_type: str, + engine_version: str, + host_instance_type: str, + ldap_server_metadata: Dict[str, Any], + logs: Dict[str, bool], + maintenance_window_start_time: Dict[str, str], + publicly_accessible: bool, + security_groups: List[str], + storage_type: str, + subnet_ids: List[str], + tags: Dict[str, str], + users: List[Dict[str, Any]], + ) -> Tuple[str, str]: broker = Broker( name=broker_name, account_id=self.account_id, @@ -404,44 +430,50 @@ class MQBackend(BaseBackend): self.create_tags(broker.arn, tags) return broker.arn, broker.id - def delete_broker(self, broker_id): + def delete_broker(self, broker_id: str) -> None: del self.brokers[broker_id] - def describe_broker(self, broker_id): + def describe_broker(self, broker_id: str) -> Broker: if broker_id not in self.brokers: raise UnknownBroker(broker_id) return self.brokers[broker_id] - def reboot_broker(self, broker_id): + def reboot_broker(self, broker_id: str) -> None: self.brokers[broker_id].reboot() - def list_brokers(self): + def list_brokers(self) -> Iterable[Broker]: """ Pagination is not yet implemented """ return self.brokers.values() - def create_user(self, broker_id, username, console_access, groups): + def create_user( + self, broker_id: str, username: str, console_access: bool, groups: List[str] + ) -> None: broker = self.describe_broker(broker_id) broker.create_user(username, console_access, groups) - def update_user(self, broker_id, console_access, groups, username): + def update_user( + self, broker_id: str, console_access: bool, groups: List[str], username: str + ) -> None: broker = self.describe_broker(broker_id) broker.update_user(username, console_access, groups) - def describe_user(self, broker_id, username): + def describe_user(self, broker_id: str, username: str) -> User: broker = self.describe_broker(broker_id) return broker.get_user(username) - def delete_user(self, broker_id, username): + def delete_user(self, broker_id: str, username: str) -> None: broker = self.describe_broker(broker_id) broker.delete_user(username) - def list_users(self, broker_id): + def list_users(self, broker_id: str) -> Iterable[User]: broker = self.describe_broker(broker_id) return broker.list_users() - def create_configuration(self, name, engine_type, engine_version, tags): + def create_configuration( + self, name: str, engine_type: str, engine_version: str, tags: Dict[str, str] + ) -> Configuration: if engine_type.upper() == "RABBITMQ": raise UnsupportedEngineType(engine_type) if engine_type.upper() != "ACTIVEMQ": @@ -459,7 +491,9 @@ class MQBackend(BaseBackend): ) return config - def update_configuration(self, config_id, data, description): + def update_configuration( + self, config_id: str, data: str, description: str + ) -> Configuration: """ No validation occurs on the provided XML. The authenticationStrategy may be changed depending on the provided configuration. """ @@ -467,47 +501,49 @@ class MQBackend(BaseBackend): config.update(data, description) return config - def describe_configuration(self, config_id): + def describe_configuration(self, config_id: str) -> Configuration: if config_id not in self.configs: raise UnknownConfiguration(config_id) return self.configs[config_id] - def describe_configuration_revision(self, config_id, revision_id): + def describe_configuration_revision( + self, config_id: str, revision_id: str + ) -> ConfigurationRevision: config = self.configs[config_id] return config.get_revision(revision_id) - def list_configurations(self): + def list_configurations(self) -> Iterable[Configuration]: """ Pagination has not yet been implemented. """ return self.configs.values() - def create_tags(self, resource_arn, tags): + def create_tags(self, resource_arn: str, tags: Dict[str, str]) -> None: self.tagger.tag_resource( resource_arn, self.tagger.convert_dict_to_tags_input(tags) ) - def list_tags(self, arn): + def list_tags(self, arn: str) -> Dict[str, str]: return self.tagger.get_tag_dict_for_resource(arn) - def delete_tags(self, resource_arn, tag_keys): + def delete_tags(self, resource_arn: str, tag_keys: List[str]) -> None: if not isinstance(tag_keys, list): tag_keys = [tag_keys] self.tagger.untag_resource_using_names(resource_arn, tag_keys) def update_broker( self, - authentication_strategy, - auto_minor_version_upgrade, - broker_id, - configuration, - engine_version, - host_instance_type, - ldap_server_metadata, - logs, - maintenance_window_start_time, - security_groups, - ): + authentication_strategy: str, + auto_minor_version_upgrade: bool, + broker_id: str, + configuration: Dict[str, Any], + engine_version: str, + host_instance_type: str, + ldap_server_metadata: Dict[str, Any], + logs: Dict[str, bool], + maintenance_window_start_time: Dict[str, str], + security_groups: List[str], + ) -> None: broker = self.describe_broker(broker_id) broker.update( authentication_strategy=authentication_strategy, diff --git a/moto/mq/responses.py b/moto/mq/responses.py index bee19d039..e4d1c5599 100644 --- a/moto/mq/responses.py +++ b/moto/mq/responses.py @@ -1,23 +1,25 @@ """Handles incoming mq requests, invokes methods, returns responses.""" import json +from typing import Any from urllib.parse import unquote +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse -from .models import mq_backends +from .models import mq_backends, MQBackend class MQResponse(BaseResponse): """Handler for MQ requests and responses.""" - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="mq") @property - def mq_backend(self): + def mq_backend(self) -> MQBackend: """Return backend instance specific for this region.""" return mq_backends[self.current_account][self.region] - def broker(self, request, full_url, headers): + def broker(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.describe_broker() @@ -26,40 +28,40 @@ class MQResponse(BaseResponse): if request.method == "PUT": return self.update_broker() - def brokers(self, request, full_url, headers): + def brokers(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": return self.create_broker() if request.method == "GET": return self.list_brokers() - def configuration(self, request, full_url, headers): + def configuration(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.describe_configuration() if request.method == "PUT": return self.update_configuration() - def configurations(self, request, full_url, headers): + def configurations(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": return self.create_configuration() if request.method == "GET": return self.list_configurations() - def configuration_revision(self, request, full_url, headers): + def configuration_revision(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.get_configuration_revision() - def tags(self, request, full_url, headers): + def tags(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": return self.create_tags() if request.method == "DELETE": return self.delete_tags() - def user(self, request, full_url, headers): + def user(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": return self.create_user() @@ -70,12 +72,12 @@ class MQResponse(BaseResponse): if request.method == "DELETE": return self.delete_user() - def users(self, request, full_url, headers): + def users(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.list_users() - def create_broker(self): + def create_broker(self) -> TYPE_RESPONSE: params = json.loads(self.body) authentication_strategy = params.get("authenticationStrategy") auto_minor_version_upgrade = params.get("autoMinorVersionUpgrade") @@ -119,7 +121,7 @@ class MQResponse(BaseResponse): resp = {"brokerArn": broker_arn, "brokerId": broker_id} return 200, {}, json.dumps(resp) - def update_broker(self): + def update_broker(self) -> TYPE_RESPONSE: params = json.loads(self.body) broker_id = self.path.split("/")[-1] authentication_strategy = params.get("authenticationStrategy") @@ -145,23 +147,23 @@ class MQResponse(BaseResponse): ) return self.describe_broker() - def delete_broker(self): + def delete_broker(self) -> TYPE_RESPONSE: broker_id = self.path.split("/")[-1] self.mq_backend.delete_broker(broker_id=broker_id) return 200, {}, json.dumps(dict(brokerId=broker_id)) - def describe_broker(self): + def describe_broker(self) -> TYPE_RESPONSE: broker_id = self.path.split("/")[-1] broker = self.mq_backend.describe_broker(broker_id=broker_id) resp = broker.to_json() resp["tags"] = self.mq_backend.list_tags(broker.arn) return 200, {}, json.dumps(resp) - def list_brokers(self): + def list_brokers(self) -> TYPE_RESPONSE: brokers = self.mq_backend.list_brokers() return 200, {}, json.dumps(dict(brokerSummaries=[b.summary() for b in brokers])) - def create_user(self): + def create_user(self) -> TYPE_RESPONSE: params = json.loads(self.body) broker_id = self.path.split("/")[-3] username = self.path.split("/")[-1] @@ -170,7 +172,7 @@ class MQResponse(BaseResponse): self.mq_backend.create_user(broker_id, username, console_access, groups) return 200, {}, "{}" - def update_user(self): + def update_user(self) -> TYPE_RESPONSE: params = json.loads(self.body) broker_id = self.path.split("/")[-3] username = self.path.split("/")[-1] @@ -184,19 +186,19 @@ class MQResponse(BaseResponse): ) return 200, {}, "{}" - def describe_user(self): + def describe_user(self) -> TYPE_RESPONSE: broker_id = self.path.split("/")[-3] username = self.path.split("/")[-1] user = self.mq_backend.describe_user(broker_id, username) return 200, {}, json.dumps(user.to_json()) - def delete_user(self): + def delete_user(self) -> TYPE_RESPONSE: broker_id = self.path.split("/")[-3] username = self.path.split("/")[-1] self.mq_backend.delete_user(broker_id, username) return 200, {}, "{}" - def list_users(self): + def list_users(self) -> TYPE_RESPONSE: broker_id = self.path.split("/")[-2] users = self.mq_backend.list_users(broker_id=broker_id) resp = { @@ -205,7 +207,7 @@ class MQResponse(BaseResponse): } return 200, {}, json.dumps(resp) - def create_configuration(self): + def create_configuration(self) -> TYPE_RESPONSE: params = json.loads(self.body) name = params.get("name") engine_type = params.get("engineType") @@ -217,19 +219,19 @@ class MQResponse(BaseResponse): ) return 200, {}, json.dumps(config.to_json()) - def describe_configuration(self): + def describe_configuration(self) -> TYPE_RESPONSE: config_id = self.path.split("/")[-1] config = self.mq_backend.describe_configuration(config_id) resp = config.to_json() resp["tags"] = self.mq_backend.list_tags(config.arn) return 200, {}, json.dumps(resp) - def list_configurations(self): + def list_configurations(self) -> TYPE_RESPONSE: configs = self.mq_backend.list_configurations() resp = {"configurations": [c.to_json() for c in configs]} return 200, {}, json.dumps(resp) - def update_configuration(self): + def update_configuration(self) -> TYPE_RESPONSE: config_id = self.path.split("/")[-1] params = json.loads(self.body) data = params.get("data") @@ -237,7 +239,7 @@ class MQResponse(BaseResponse): config = self.mq_backend.update_configuration(config_id, data, description) return 200, {}, json.dumps(config.to_json()) - def get_configuration_revision(self): + def get_configuration_revision(self) -> TYPE_RESPONSE: revision_id = self.path.split("/")[-1] config_id = self.path.split("/")[-3] revision = self.mq_backend.describe_configuration_revision( @@ -245,19 +247,19 @@ class MQResponse(BaseResponse): ) return 200, {}, json.dumps(revision.to_json()) - def create_tags(self): + def create_tags(self) -> TYPE_RESPONSE: resource_arn = unquote(self.path.split("/")[-1]) tags = json.loads(self.body).get("tags", {}) self.mq_backend.create_tags(resource_arn, tags) return 200, {}, "{}" - def delete_tags(self): + def delete_tags(self) -> TYPE_RESPONSE: resource_arn = unquote(self.path.split("/")[-1]) tag_keys = self._get_param("tagKeys") self.mq_backend.delete_tags(resource_arn, tag_keys) return 200, {}, "{}" - def reboot(self, request, full_url, headers): + def reboot(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": broker_id = self.path.split("/")[-2] diff --git a/setup.cfg b/setup.cfg index e57eaf2c4..c91bbe3c9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -235,7 +235,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/managedblockchain,moto/moto_api,moto/neptune,moto/opensearch,moto/rdsdata +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/neptune,moto/opensearch,moto/rdsdata show_column_numbers=True show_error_codes = True disable_error_code=abstract diff --git a/tests/test_mediaconnect/test_mediaconnect.py b/tests/test_mediaconnect/test_mediaconnect.py index 3f01ecdd4..bb60e3e62 100644 --- a/tests/test_mediaconnect/test_mediaconnect.py +++ b/tests/test_mediaconnect/test_mediaconnect.py @@ -205,6 +205,31 @@ def test_start_stop_flow_succeeds(): describe_response["Flow"]["Status"].should.equal("STANDBY") +@mock_mediaconnect +def test_unknown_flow(): + client = boto3.client("mediaconnect", region_name=region) + + with pytest.raises(ClientError) as exc: + client.describe_flow(FlowArn="unknown") + assert exc.value.response["Error"]["Code"] == "NotFoundException" + + with pytest.raises(ClientError) as exc: + client.delete_flow(FlowArn="unknown") + assert exc.value.response["Error"]["Code"] == "NotFoundException" + + with pytest.raises(ClientError) as exc: + client.start_flow(FlowArn="unknown") + assert exc.value.response["Error"]["Code"] == "NotFoundException" + + with pytest.raises(ClientError) as exc: + client.stop_flow(FlowArn="unknown") + assert exc.value.response["Error"]["Code"] == "NotFoundException" + + with pytest.raises(ClientError) as exc: + client.list_tags_for_resource(ResourceArn="unknown") + assert exc.value.response["Error"]["Code"] == "NotFoundException" + + @mock_mediaconnect def test_tag_resource_succeeds(): client = boto3.client("mediaconnect", region_name=region)