Techdebt: MyPy M (#6170)

This commit is contained in:
Bert Blommers 2023-04-03 23:50:19 +01:00 committed by GitHub
parent 52870d6114
commit 706ff9f5e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 752 additions and 746 deletions

View File

@ -4,5 +4,5 @@ from moto.core.exceptions import JsonRESTError
class NotFoundException(JsonRESTError): class NotFoundException(JsonRESTError):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("NotFoundException", message) super().__init__("NotFoundException", message)

View File

@ -1,12 +1,15 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.mediaconnect.exceptions import NotFoundException from moto.mediaconnect.exceptions import NotFoundException
from moto.moto_api._internal import mock_random as random from moto.moto_api._internal import mock_random as random
from moto.utilities.tagging_service import TaggingService
class Flow(BaseModel): class Flow(BaseModel):
def __init__(self, **kwargs): def __init__(self, account_id: str, region_name: str, **kwargs: Any):
self.id = random.uuid4().hex
self.availability_zone = kwargs.get("availability_zone") self.availability_zone = kwargs.get("availability_zone")
self.entitlements = kwargs.get("entitlements", []) self.entitlements = kwargs.get("entitlements", [])
self.name = kwargs.get("name") self.name = kwargs.get("name")
@ -15,17 +18,19 @@ class Flow(BaseModel):
self.source_failover_config = kwargs.get("source_failover_config", {}) self.source_failover_config = kwargs.get("source_failover_config", {})
self.sources = kwargs.get("sources", []) self.sources = kwargs.get("sources", [])
self.vpc_interfaces = kwargs.get("vpc_interfaces", []) self.vpc_interfaces = kwargs.get("vpc_interfaces", [])
self.status = "STANDBY" # one of 'STANDBY'|'ACTIVE'|'UPDATING'|'DELETING'|'STARTING'|'STOPPING'|'ERROR' self.status: Optional[
self._previous_status = None str
self.description = None ] = "STANDBY" # one of 'STANDBY'|'ACTIVE'|'UPDATING'|'DELETING'|'STARTING'|'STOPPING'|'ERROR'
self.flow_arn = None self._previous_status: Optional[str] = None
self.egress_ip = None self.description = "A Moto test flow"
self.flow_arn = f"arn:aws:mediaconnect:{region_name}:{account_id}:flow:{self.id}:{self.name}"
self.egress_ip = "127.0.0.1"
if self.source and not self.sources: if self.source and not self.sources:
self.sources = [ self.sources = [
self.source, self.source,
] ]
def to_dict(self, include=None): def to_dict(self, include: Optional[List[str]] = None) -> Dict[str, Any]:
data = { data = {
"availabilityZone": self.availability_zone, "availabilityZone": self.availability_zone,
"description": self.description, "description": self.description,
@ -47,7 +52,7 @@ class Flow(BaseModel):
return new_data return new_data
return data return data
def resolve_transient_states(self): def resolve_transient_states(self) -> None:
if self.status in ["STARTING"]: if self.status in ["STARTING"]:
self.status = "ACTIVE" self.status = "ACTIVE"
if self.status in ["STOPPING"]: if self.status in ["STOPPING"]:
@ -57,26 +62,18 @@ class Flow(BaseModel):
self._previous_status = None self._previous_status = None
class Resource(BaseModel):
def __init__(self, **kwargs):
self.resource_arn = kwargs.get("resource_arn")
self.tags = OrderedDict()
def to_dict(self):
data = {
"resourceArn": self.resource_arn,
"tags": self.tags,
}
return data
class MediaConnectBackend(BaseBackend): class MediaConnectBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self._flows = OrderedDict() self._flows: Dict[str, Flow] = OrderedDict()
self._resources = OrderedDict() self.tagger = TaggingService()
def _add_source_details(self, source, flow_id, ingest_ip="127.0.0.1"): def _add_source_details(
self,
source: Optional[Dict[str, Any]],
flow_id: str,
ingest_ip: str = "127.0.0.1",
) -> None:
if source: if source:
source["sourceArn"] = ( source["sourceArn"] = (
f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source" f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source"
@ -85,7 +82,9 @@ class MediaConnectBackend(BaseBackend):
if not source.get("entitlementArn"): if not source.get("entitlementArn"):
source["ingestIp"] = ingest_ip source["ingestIp"] = ingest_ip
def _add_entitlement_details(self, entitlement, entitlement_id): def _add_entitlement_details(
self, entitlement: Optional[Dict[str, Any]], entitlement_id: str
) -> None:
if entitlement: if entitlement:
entitlement["entitlementArn"] = ( entitlement["entitlementArn"] = (
f"arn:aws:mediaconnect:{self.region_name}" f"arn:aws:mediaconnect:{self.region_name}"
@ -93,15 +92,9 @@ class MediaConnectBackend(BaseBackend):
f":{entitlement['name']}" f":{entitlement['name']}"
) )
def _create_flow_add_details(self, flow): def _create_flow_add_details(self, flow: Flow) -> None:
flow_id = random.uuid4().hex
flow.description = "A Moto test flow"
flow.egress_ip = "127.0.0.1"
flow.flow_arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:flow:{flow_id}:{flow.name}"
for index, _source in enumerate(flow.sources): for index, _source in enumerate(flow.sources):
self._add_source_details(_source, flow_id, f"127.0.0.{index}") self._add_source_details(_source, flow.id, f"127.0.0.{index}")
for index, output in enumerate(flow.outputs or []): for index, output in enumerate(flow.outputs or []):
if output.get("protocol") in ["srt-listener", "zixi-pull"]: if output.get("protocol") in ["srt-listener", "zixi-pull"]:
@ -119,16 +112,18 @@ class MediaConnectBackend(BaseBackend):
def create_flow( def create_flow(
self, self,
availability_zone, availability_zone: str,
entitlements, entitlements: List[Dict[str, Any]],
name, name: str,
outputs, outputs: List[Dict[str, Any]],
source, source: Dict[str, Any],
source_failover_config, source_failover_config: Dict[str, Any],
sources, sources: List[Dict[str, Any]],
vpc_interfaces, vpc_interfaces: List[Dict[str, Any]],
): ) -> Flow:
flow = Flow( flow = Flow(
account_id=self.account_id,
region_name=self.region_name,
availability_zone=availability_zone, availability_zone=availability_zone,
entitlements=entitlements, entitlements=entitlements,
name=name, name=name,
@ -142,11 +137,14 @@ class MediaConnectBackend(BaseBackend):
self._flows[flow.flow_arn] = flow self._flows[flow.flow_arn] = flow
return flow return flow
def list_flows(self, max_results, next_token): def list_flows(self, max_results: Optional[int]) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
flows = list(self._flows.values()) flows = list(self._flows.values())
if max_results is not None: if max_results is not None:
flows = flows[:max_results] flows = flows[:max_results]
response_flows = [ return [
fl.to_dict( fl.to_dict(
include=[ include=[
"availabilityZone", "availabilityZone",
@ -159,74 +157,59 @@ class MediaConnectBackend(BaseBackend):
) )
for fl in flows for fl in flows
] ]
return response_flows, next_token
def describe_flow(self, flow_arn=None): def describe_flow(self, flow_arn: str) -> Flow:
messages = {}
if flow_arn in self._flows: if flow_arn in self._flows:
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
flow.resolve_transient_states() flow.resolve_transient_states()
else: return flow
raise NotFoundException(message="Flow not found.") raise NotFoundException(message="Flow not found.")
return flow.to_dict(), messages
def delete_flow(self, flow_arn): def delete_flow(self, flow_arn: str) -> Flow:
if flow_arn in self._flows: if flow_arn in self._flows:
flow = self._flows[flow_arn] return self._flows.pop(flow_arn)
del self._flows[flow_arn]
else:
raise NotFoundException(message="Flow not found.") raise NotFoundException(message="Flow not found.")
return flow_arn, flow.status
def start_flow(self, flow_arn): def start_flow(self, flow_arn: str) -> Flow:
if flow_arn in self._flows: if flow_arn in self._flows:
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
flow.status = "STARTING" flow.status = "STARTING"
else: return flow
raise NotFoundException(message="Flow not found.") raise NotFoundException(message="Flow not found.")
return flow_arn, flow.status
def stop_flow(self, flow_arn): def stop_flow(self, flow_arn: str) -> Flow:
if flow_arn in self._flows: if flow_arn in self._flows:
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
flow.status = "STOPPING" flow.status = "STOPPING"
else: return flow
raise NotFoundException(message="Flow not found.") raise NotFoundException(message="Flow not found.")
return flow_arn, flow.status
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: Dict[str, Any]) -> None:
if resource_arn in self._resources: tag_list = TaggingService.convert_dict_to_tags_input(tags)
resource = self._resources[resource_arn] self.tagger.tag_resource(resource_arn, tag_list)
else:
resource = Resource(resource_arn=resource_arn)
resource.tags.update(tags)
self._resources[resource_arn] = resource
return None
def list_tags_for_resource(self, resource_arn): def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
if resource_arn in self._resources: if self.tagger.has_tags(resource_arn):
resource = self._resources[resource_arn] return self.tagger.get_tag_dict_for_resource(resource_arn)
else:
raise NotFoundException(message="Resource not found.") raise NotFoundException(message="Resource not found.")
return resource.tags
def add_flow_vpc_interfaces(self, flow_arn, vpc_interfaces): def add_flow_vpc_interfaces(
self, flow_arn: str, vpc_interfaces: List[Dict[str, Any]]
) -> Flow:
if flow_arn in self._flows: if flow_arn in self._flows:
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
flow.vpc_interfaces = vpc_interfaces flow.vpc_interfaces = vpc_interfaces
else: return flow
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
return flow_arn, flow.vpc_interfaces
def add_flow_outputs(self, flow_arn, outputs): def add_flow_outputs(self, flow_arn: str, outputs: List[Dict[str, Any]]) -> Flow:
if flow_arn in self._flows: if flow_arn in self._flows:
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
flow.outputs = outputs flow.outputs = outputs
else: return flow
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
return flow_arn, flow.outputs
def remove_flow_vpc_interface(self, flow_arn, vpc_interface_name): def remove_flow_vpc_interface(self, flow_arn: str, vpc_interface_name: str) -> None:
if flow_arn in self._flows: if flow_arn in self._flows:
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
flow.vpc_interfaces = [ flow.vpc_interfaces = [
@ -236,9 +219,8 @@ class MediaConnectBackend(BaseBackend):
] ]
else: else:
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
return flow_arn, vpc_interface_name
def remove_flow_output(self, flow_arn, output_name): def remove_flow_output(self, flow_arn: str, output_name: str) -> None:
if flow_arn in self._flows: if flow_arn in self._flows:
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
flow.outputs = [ flow.outputs = [
@ -248,28 +230,27 @@ class MediaConnectBackend(BaseBackend):
] ]
else: else:
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
return flow_arn, output_name
def update_flow_output( def update_flow_output(
self, self,
flow_arn, flow_arn: str,
output_arn, output_arn: str,
cidr_allow_list, cidr_allow_list: List[str],
description, description: str,
destination, destination: str,
encryption, encryption: Dict[str, str],
max_latency, max_latency: int,
media_stream_output_configuration, media_stream_output_configuration: List[Dict[str, Any]],
min_latency, min_latency: int,
port, port: int,
protocol, protocol: str,
remote_id, remote_id: str,
sender_control_port, sender_control_port: int,
sender_ip_address, sender_ip_address: str,
smoothing_latency, smoothing_latency: int,
stream_id, stream_id: str,
vpc_interface_attachment, vpc_interface_attachment: Dict[str, str],
): ) -> Dict[str, Any]:
if flow_arn not in self._flows: if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
@ -292,10 +273,12 @@ class MediaConnectBackend(BaseBackend):
output["smoothingLatency"] = smoothing_latency output["smoothingLatency"] = smoothing_latency
output["streamId"] = stream_id output["streamId"] = stream_id
output["vpcInterfaceAttachment"] = vpc_interface_attachment output["vpcInterfaceAttachment"] = vpc_interface_attachment
return flow_arn, output return output
raise NotFoundException(message=f"output with arn={output_arn} not found") raise NotFoundException(message=f"output with arn={output_arn} not found")
def add_flow_sources(self, flow_arn, sources): def add_flow_sources(
self, flow_arn: str, sources: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
if flow_arn not in self._flows: if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
@ -305,32 +288,32 @@ class MediaConnectBackend(BaseBackend):
arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source:{source_id}:{name}" arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source:{source_id}:{name}"
source["sourceArn"] = arn source["sourceArn"] = arn
flow.sources = sources flow.sources = sources
return flow_arn, sources return sources
def update_flow_source( def update_flow_source(
self, self,
flow_arn, flow_arn: str,
source_arn, source_arn: str,
decryption, decryption: str,
description, description: str,
entitlement_arn, entitlement_arn: str,
ingest_port, ingest_port: int,
max_bitrate, max_bitrate: int,
max_latency, max_latency: int,
max_sync_buffer, max_sync_buffer: int,
media_stream_source_configurations, media_stream_source_configurations: List[Dict[str, Any]],
min_latency, min_latency: int,
protocol, protocol: str,
sender_control_port, sender_control_port: int,
sender_ip_address, sender_ip_address: str,
stream_id, stream_id: str,
vpc_interface_name, vpc_interface_name: str,
whitelist_cidr, whitelist_cidr: str,
): ) -> Optional[Dict[str, Any]]:
if flow_arn not in self._flows: if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
source = next( source: Optional[Dict[str, Any]] = next(
iter( iter(
[source for source in flow.sources if source["sourceArn"] == source_arn] [source for source in flow.sources if source["sourceArn"] == source_arn]
), ),
@ -354,13 +337,13 @@ class MediaConnectBackend(BaseBackend):
source["streamId"] = stream_id source["streamId"] = stream_id
source["vpcInterfaceName"] = vpc_interface_name source["vpcInterfaceName"] = vpc_interface_name
source["whitelistCidr"] = whitelist_cidr source["whitelistCidr"] = whitelist_cidr
return flow_arn, source return source
def grant_flow_entitlements( def grant_flow_entitlements(
self, self,
flow_arn, flow_arn: str,
entitlements, entitlements: List[Dict[str, Any]],
): ) -> List[Dict[str, Any]]:
if flow_arn not in self._flows: if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
@ -371,30 +354,30 @@ class MediaConnectBackend(BaseBackend):
entitlement["entitlementArn"] = arn entitlement["entitlementArn"] = arn
flow.entitlements += entitlements flow.entitlements += entitlements
return flow_arn, entitlements return entitlements
def revoke_flow_entitlement(self, flow_arn, entitlement_arn): def revoke_flow_entitlement(self, flow_arn: str, entitlement_arn: str) -> None:
if flow_arn not in self._flows: if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
for entitlement in flow.entitlements: for entitlement in flow.entitlements:
if entitlement_arn == entitlement["entitlementArn"]: if entitlement_arn == entitlement["entitlementArn"]:
flow.entitlements.remove(entitlement) flow.entitlements.remove(entitlement)
return flow_arn, entitlement_arn return
raise NotFoundException( raise NotFoundException(
message=f"entitlement with arn={entitlement_arn} not found" message=f"entitlement with arn={entitlement_arn} not found"
) )
def update_flow_entitlement( def update_flow_entitlement(
self, self,
flow_arn, flow_arn: str,
entitlement_arn, entitlement_arn: str,
description, description: str,
encryption, encryption: Dict[str, str],
entitlement_status, entitlement_status: str,
name, name: str,
subscribers, subscribers: List[str],
): ) -> Dict[str, Any]:
if flow_arn not in self._flows: if flow_arn not in self._flows:
raise NotFoundException(message=f"flow with arn={flow_arn} not found") raise NotFoundException(message=f"flow with arn={flow_arn} not found")
flow = self._flows[flow_arn] flow = self._flows[flow_arn]
@ -405,12 +388,10 @@ class MediaConnectBackend(BaseBackend):
entitlement["entitlementStatus"] = entitlement_status entitlement["entitlementStatus"] = entitlement_status
entitlement["name"] = name entitlement["name"] = name
entitlement["subscribers"] = subscribers entitlement["subscribers"] = subscribers
return flow_arn, entitlement return entitlement
raise NotFoundException( raise NotFoundException(
message=f"entitlement with arn={entitlement_arn} not found" message=f"entitlement with arn={entitlement_arn} not found"
) )
# add methods from here
mediaconnect_backends = BackendDict(MediaConnectBackend, "mediaconnect") mediaconnect_backends = BackendDict(MediaConnectBackend, "mediaconnect")

View File

@ -1,20 +1,20 @@
import json import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import mediaconnect_backends from .models import mediaconnect_backends, MediaConnectBackend
from urllib.parse import unquote from urllib.parse import unquote
class MediaConnectResponse(BaseResponse): class MediaConnectResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="mediaconnect") super().__init__(service_name="mediaconnect")
@property @property
def mediaconnect_backend(self): def mediaconnect_backend(self) -> MediaConnectBackend:
return mediaconnect_backends[self.current_account][self.region] return mediaconnect_backends[self.current_account][self.region]
def create_flow(self): def create_flow(self) -> str:
availability_zone = self._get_param("availabilityZone") availability_zone = self._get_param("availabilityZone")
entitlements = self._get_param("entitlements") entitlements = self._get_param("entitlements")
name = self._get_param("name") name = self._get_param("name")
@ -35,85 +35,79 @@ class MediaConnectResponse(BaseResponse):
) )
return json.dumps(dict(flow=flow.to_dict())) return json.dumps(dict(flow=flow.to_dict()))
def list_flows(self): def list_flows(self) -> str:
max_results = self._get_int_param("maxResults") max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken") flows = self.mediaconnect_backend.list_flows(max_results=max_results)
flows, next_token = self.mediaconnect_backend.list_flows( return json.dumps(dict(flows=flows))
max_results=max_results, next_token=next_token
)
return json.dumps(dict(flows=flows, nextToken=next_token))
def describe_flow(self): def describe_flow(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
flow, messages = self.mediaconnect_backend.describe_flow(flow_arn=flow_arn) flow = self.mediaconnect_backend.describe_flow(flow_arn=flow_arn)
return json.dumps(dict(flow=flow, messages=messages)) return json.dumps(dict(flow=flow.to_dict()))
def delete_flow(self): def delete_flow(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
flow_arn, status = self.mediaconnect_backend.delete_flow(flow_arn=flow_arn) flow = self.mediaconnect_backend.delete_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow_arn, status=status)) return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status))
def start_flow(self): def start_flow(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
flow_arn, status = self.mediaconnect_backend.start_flow(flow_arn=flow_arn) flow = self.mediaconnect_backend.start_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow_arn, status=status)) return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status))
def stop_flow(self): def stop_flow(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
flow_arn, status = self.mediaconnect_backend.stop_flow(flow_arn=flow_arn) flow = self.mediaconnect_backend.stop_flow(flow_arn=flow_arn)
return json.dumps(dict(flowArn=flow_arn, status=status)) return json.dumps(dict(flowArn=flow.flow_arn, status=flow.status))
def tag_resource(self): def tag_resource(self) -> str:
resource_arn = unquote(self._get_param("resourceArn")) resource_arn = unquote(self._get_param("resourceArn"))
tags = self._get_param("tags") tags = self._get_param("tags")
self.mediaconnect_backend.tag_resource(resource_arn=resource_arn, tags=tags) self.mediaconnect_backend.tag_resource(resource_arn=resource_arn, tags=tags)
return json.dumps(dict()) return json.dumps(dict())
def list_tags_for_resource(self): def list_tags_for_resource(self) -> str:
resource_arn = unquote(self._get_param("resourceArn")) resource_arn = unquote(self._get_param("resourceArn"))
tags = self.mediaconnect_backend.list_tags_for_resource( tags = self.mediaconnect_backend.list_tags_for_resource(resource_arn)
resource_arn=resource_arn
)
return json.dumps(dict(tags=tags)) return json.dumps(dict(tags=tags))
def add_flow_vpc_interfaces(self): def add_flow_vpc_interfaces(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
vpc_interfaces = self._get_param("vpcInterfaces") vpc_interfaces = self._get_param("vpcInterfaces")
flow_arn, vpc_interfaces = self.mediaconnect_backend.add_flow_vpc_interfaces( flow = self.mediaconnect_backend.add_flow_vpc_interfaces(
flow_arn=flow_arn, vpc_interfaces=vpc_interfaces flow_arn=flow_arn, vpc_interfaces=vpc_interfaces
) )
return json.dumps(dict(flow_arn=flow_arn, vpc_interfaces=vpc_interfaces)) return json.dumps(
dict(flow_arn=flow.flow_arn, vpc_interfaces=flow.vpc_interfaces)
)
def remove_flow_vpc_interface(self): def remove_flow_vpc_interface(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
vpc_interface_name = unquote(self._get_param("vpcInterfaceName")) vpc_interface_name = unquote(self._get_param("vpcInterfaceName"))
( self.mediaconnect_backend.remove_flow_vpc_interface(
flow_arn,
vpc_interface_name,
) = self.mediaconnect_backend.remove_flow_vpc_interface(
flow_arn=flow_arn, vpc_interface_name=vpc_interface_name flow_arn=flow_arn, vpc_interface_name=vpc_interface_name
) )
return json.dumps( return json.dumps(
dict(flow_arn=flow_arn, vpc_interface_name=vpc_interface_name) dict(flow_arn=flow_arn, vpc_interface_name=vpc_interface_name)
) )
def add_flow_outputs(self): def add_flow_outputs(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
outputs = self._get_param("outputs") outputs = self._get_param("outputs")
flow_arn, outputs = self.mediaconnect_backend.add_flow_outputs( flow = self.mediaconnect_backend.add_flow_outputs(
flow_arn=flow_arn, outputs=outputs flow_arn=flow_arn, outputs=outputs
) )
return json.dumps(dict(flow_arn=flow_arn, outputs=outputs)) return json.dumps(dict(flow_arn=flow.flow_arn, outputs=flow.outputs))
def remove_flow_output(self): def remove_flow_output(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
output_name = unquote(self._get_param("outputArn")) output_name = unquote(self._get_param("outputArn"))
flow_arn, output_name = self.mediaconnect_backend.remove_flow_output( self.mediaconnect_backend.remove_flow_output(
flow_arn=flow_arn, output_name=output_name flow_arn=flow_arn, output_name=output_name
) )
return json.dumps(dict(flow_arn=flow_arn, output_name=output_name)) return json.dumps(dict(flow_arn=flow_arn, output_name=output_name))
def update_flow_output(self): def update_flow_output(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
output_arn = unquote(self._get_param("outputArn")) output_arn = unquote(self._get_param("outputArn"))
cidr_allow_list = self._get_param("cidrAllowList") cidr_allow_list = self._get_param("cidrAllowList")
@ -133,7 +127,7 @@ class MediaConnectResponse(BaseResponse):
smoothing_latency = self._get_param("smoothingLatency") smoothing_latency = self._get_param("smoothingLatency")
stream_id = self._get_param("streamId") stream_id = self._get_param("streamId")
vpc_interface_attachment = self._get_param("vpcInterfaceAttachment") vpc_interface_attachment = self._get_param("vpcInterfaceAttachment")
flow_arn, output = self.mediaconnect_backend.update_flow_output( output = self.mediaconnect_backend.update_flow_output(
flow_arn=flow_arn, flow_arn=flow_arn,
output_arn=output_arn, output_arn=output_arn,
cidr_allow_list=cidr_allow_list, cidr_allow_list=cidr_allow_list,
@ -154,15 +148,15 @@ class MediaConnectResponse(BaseResponse):
) )
return json.dumps(dict(flowArn=flow_arn, output=output)) return json.dumps(dict(flowArn=flow_arn, output=output))
def add_flow_sources(self): def add_flow_sources(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
sources = self._get_param("sources") sources = self._get_param("sources")
flow_arn, sources = self.mediaconnect_backend.add_flow_sources( sources = self.mediaconnect_backend.add_flow_sources(
flow_arn=flow_arn, sources=sources flow_arn=flow_arn, sources=sources
) )
return json.dumps(dict(flow_arn=flow_arn, sources=sources)) return json.dumps(dict(flow_arn=flow_arn, sources=sources))
def update_flow_source(self): def update_flow_source(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
source_arn = unquote(self._get_param("sourceArn")) source_arn = unquote(self._get_param("sourceArn"))
description = self._get_param("description") description = self._get_param("description")
@ -182,7 +176,7 @@ class MediaConnectResponse(BaseResponse):
stream_id = self._get_param("streamId") stream_id = self._get_param("streamId")
vpc_interface_name = self._get_param("vpcInterfaceName") vpc_interface_name = self._get_param("vpcInterfaceName")
whitelist_cidr = self._get_param("whitelistCidr") whitelist_cidr = self._get_param("whitelistCidr")
flow_arn, source = self.mediaconnect_backend.update_flow_source( source = self.mediaconnect_backend.update_flow_source(
flow_arn=flow_arn, flow_arn=flow_arn,
source_arn=source_arn, source_arn=source_arn,
decryption=decryption, decryption=decryption,
@ -203,23 +197,23 @@ class MediaConnectResponse(BaseResponse):
) )
return json.dumps(dict(flow_arn=flow_arn, source=source)) return json.dumps(dict(flow_arn=flow_arn, source=source))
def grant_flow_entitlements(self): def grant_flow_entitlements(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
entitlements = self._get_param("entitlements") entitlements = self._get_param("entitlements")
flow_arn, entitlements = self.mediaconnect_backend.grant_flow_entitlements( entitlements = self.mediaconnect_backend.grant_flow_entitlements(
flow_arn=flow_arn, entitlements=entitlements flow_arn=flow_arn, entitlements=entitlements
) )
return json.dumps(dict(flow_arn=flow_arn, entitlements=entitlements)) return json.dumps(dict(flow_arn=flow_arn, entitlements=entitlements))
def revoke_flow_entitlement(self): def revoke_flow_entitlement(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
entitlement_arn = unquote(self._get_param("entitlementArn")) entitlement_arn = unquote(self._get_param("entitlementArn"))
flow_arn, entitlement_arn = self.mediaconnect_backend.revoke_flow_entitlement( self.mediaconnect_backend.revoke_flow_entitlement(
flow_arn=flow_arn, entitlement_arn=entitlement_arn flow_arn=flow_arn, entitlement_arn=entitlement_arn
) )
return json.dumps(dict(flowArn=flow_arn, entitlementArn=entitlement_arn)) return json.dumps(dict(flowArn=flow_arn, entitlementArn=entitlement_arn))
def update_flow_entitlement(self): def update_flow_entitlement(self) -> str:
flow_arn = unquote(self._get_param("flowArn")) flow_arn = unquote(self._get_param("flowArn"))
entitlement_arn = unquote(self._get_param("entitlementArn")) entitlement_arn = unquote(self._get_param("entitlementArn"))
description = self._get_param("description") description = self._get_param("description")
@ -227,7 +221,7 @@ class MediaConnectResponse(BaseResponse):
entitlement_status = self._get_param("entitlementStatus") entitlement_status = self._get_param("entitlementStatus")
name = self._get_param("name") name = self._get_param("name")
subscribers = self._get_param("subscribers") subscribers = self._get_param("subscribers")
flow_arn, entitlement = self.mediaconnect_backend.update_flow_entitlement( entitlement = self.mediaconnect_backend.update_flow_entitlement(
flow_arn=flow_arn, flow_arn=flow_arn,
entitlement_arn=entitlement_arn, entitlement_arn=entitlement_arn,
description=description, description=description,

View File

@ -1,11 +1,12 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
class Input(BaseModel): class Input(BaseModel):
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn") self.arn = kwargs.get("arn")
self.attached_channels = kwargs.get("attached_channels", []) self.attached_channels = kwargs.get("attached_channels", [])
self.destinations = kwargs.get("destinations", []) self.destinations = kwargs.get("destinations", [])
@ -23,8 +24,8 @@ class Input(BaseModel):
self.tags = kwargs.get("tags") self.tags = kwargs.get("tags")
self.input_type = kwargs.get("input_type") self.input_type = kwargs.get("input_type")
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
data = { return {
"arn": self.arn, "arn": self.arn,
"attachedChannels": self.attached_channels, "attachedChannels": self.attached_channels,
"destinations": self.destinations, "destinations": self.destinations,
@ -41,9 +42,8 @@ class Input(BaseModel):
"tags": self.tags, "tags": self.tags,
"type": self.input_type, "type": self.input_type,
} }
return data
def _resolve_transient_states(self): def _resolve_transient_states(self) -> None:
# Resolve transient states before second call # Resolve transient states before second call
# (to simulate AWS taking its sweet time with these things) # (to simulate AWS taking its sweet time with these things)
if self.state in ["CREATING"]: if self.state in ["CREATING"]:
@ -53,7 +53,7 @@ class Input(BaseModel):
class Channel(BaseModel): class Channel(BaseModel):
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn") self.arn = kwargs.get("arn")
self.cdi_input_specification = kwargs.get("cdi_input_specification") self.cdi_input_specification = kwargs.get("cdi_input_specification")
self.channel_class = kwargs.get("channel_class", "STANDARD") self.channel_class = kwargs.get("channel_class", "STANDARD")
@ -71,7 +71,7 @@ class Channel(BaseModel):
self.tags = kwargs.get("tags") self.tags = kwargs.get("tags")
self._previous_state = None self._previous_state = None
def to_dict(self, exclude=None): def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]:
data = { data = {
"arn": self.arn, "arn": self.arn,
"cdiInputSpecification": self.cdi_input_specification, "cdiInputSpecification": self.cdi_input_specification,
@ -97,7 +97,7 @@ class Channel(BaseModel):
del data[key] del data[key]
return data return data
def _resolve_transient_states(self): def _resolve_transient_states(self) -> None:
# Resolve transient states before second call # Resolve transient states before second call
# (to simulate AWS taking its sweet time with these things) # (to simulate AWS taking its sweet time with these things)
if self.state in ["CREATING", "STOPPING"]: if self.state in ["CREATING", "STOPPING"]:
@ -112,24 +112,24 @@ class Channel(BaseModel):
class MediaLiveBackend(BaseBackend): class MediaLiveBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self._channels = OrderedDict() self._channels: Dict[str, Channel] = OrderedDict()
self._inputs = OrderedDict() self._inputs: Dict[str, Input] = OrderedDict()
def create_channel( def create_channel(
self, self,
cdi_input_specification, cdi_input_specification: Dict[str, Any],
channel_class, channel_class: str,
destinations, destinations: List[Dict[str, Any]],
encoder_settings, encoder_settings: Dict[str, Any],
input_attachments, input_attachments: List[Dict[str, Any]],
input_specification, input_specification: Dict[str, str],
log_level, log_level: str,
name, name: str,
role_arn, role_arn: str,
tags, tags: Dict[str, str],
): ) -> Channel:
""" """
The RequestID and Reserved parameters are not yet implemented The RequestID and Reserved parameters are not yet implemented
""" """
@ -155,47 +155,49 @@ class MediaLiveBackend(BaseBackend):
self._channels[channel_id] = channel self._channels[channel_id] = channel
return channel return channel
def list_channels(self, max_results, next_token): def list_channels(self, max_results: Optional[int]) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
channels = list(self._channels.values()) channels = list(self._channels.values())
if max_results is not None: if max_results is not None:
channels = channels[:max_results] channels = channels[:max_results]
response_channels = [ return [
c.to_dict(exclude=["encoderSettings", "pipelineDetails"]) for c in channels c.to_dict(exclude=["encoderSettings", "pipelineDetails"]) for c in channels
] ]
return response_channels, next_token
def describe_channel(self, channel_id): def describe_channel(self, channel_id: str) -> Channel:
channel = self._channels[channel_id] channel = self._channels[channel_id]
channel._resolve_transient_states() channel._resolve_transient_states()
return channel.to_dict() return channel
def delete_channel(self, channel_id): def delete_channel(self, channel_id: str) -> Channel:
channel = self._channels[channel_id] channel = self._channels[channel_id]
channel.state = "DELETING" channel.state = "DELETING"
return channel.to_dict() return channel
def start_channel(self, channel_id): def start_channel(self, channel_id: str) -> Channel:
channel = self._channels[channel_id] channel = self._channels[channel_id]
channel.state = "STARTING" channel.state = "STARTING"
return channel.to_dict() return channel
def stop_channel(self, channel_id): def stop_channel(self, channel_id: str) -> Channel:
channel = self._channels[channel_id] channel = self._channels[channel_id]
channel.state = "STOPPING" channel.state = "STOPPING"
return channel.to_dict() return channel
def update_channel( def update_channel(
self, self,
channel_id, channel_id: str,
cdi_input_specification, cdi_input_specification: Dict[str, str],
destinations, destinations: List[Dict[str, Any]],
encoder_settings, encoder_settings: Dict[str, Any],
input_attachments, input_attachments: List[Dict[str, Any]],
input_specification, input_specification: Dict[str, str],
log_level, log_level: str,
name, name: str,
role_arn, role_arn: str,
): ) -> Channel:
channel = self._channels[channel_id] channel = self._channels[channel_id]
channel.cdi_input_specification = cdi_input_specification channel.cdi_input_specification = cdi_input_specification
channel.destinations = destinations channel.destinations = destinations
@ -214,16 +216,16 @@ class MediaLiveBackend(BaseBackend):
def create_input( def create_input(
self, self,
destinations, destinations: List[Dict[str, str]],
input_devices, input_devices: List[Dict[str, str]],
input_security_groups, input_security_groups: List[str],
media_connect_flows, media_connect_flows: List[Dict[str, str]],
name, name: str,
role_arn, role_arn: str,
sources, sources: List[Dict[str, str]],
tags, tags: Dict[str, str],
input_type, input_type: str,
): ) -> Input:
""" """
The VPC and RequestId parameters are not yet implemented The VPC and RequestId parameters are not yet implemented
""" """
@ -246,34 +248,35 @@ class MediaLiveBackend(BaseBackend):
self._inputs[input_id] = a_input self._inputs[input_id] = a_input
return a_input return a_input
def describe_input(self, input_id): def describe_input(self, input_id: str) -> Input:
a_input = self._inputs[input_id] a_input = self._inputs[input_id]
a_input._resolve_transient_states() a_input._resolve_transient_states()
return a_input.to_dict() return a_input
def list_inputs(self, max_results, next_token): def list_inputs(self, max_results: Optional[int]) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
inputs = list(self._inputs.values()) inputs = list(self._inputs.values())
if max_results is not None: if max_results is not None:
inputs = inputs[:max_results] inputs = inputs[:max_results]
response_inputs = [i.to_dict() for i in inputs] return [i.to_dict() for i in inputs]
return response_inputs, next_token
def delete_input(self, input_id): def delete_input(self, input_id: str) -> None:
a_input = self._inputs[input_id] a_input = self._inputs[input_id]
a_input.state = "DELETING" a_input.state = "DELETING"
return a_input.to_dict()
def update_input( def update_input(
self, self,
destinations, destinations: List[Dict[str, str]],
input_devices, input_devices: List[Dict[str, str]],
input_id, input_id: str,
input_security_groups, input_security_groups: List[str],
media_connect_flows, media_connect_flows: List[Dict[str, str]],
name, name: str,
role_arn, role_arn: str,
sources, sources: List[Dict[str, str]],
): ) -> Input:
a_input = self._inputs[input_id] a_input = self._inputs[input_id]
a_input.destinations = destinations a_input.destinations = destinations
a_input.input_devices = input_devices a_input.input_devices = input_devices

View File

@ -1,17 +1,17 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import medialive_backends from .models import medialive_backends, MediaLiveBackend
import json import json
class MediaLiveResponse(BaseResponse): class MediaLiveResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="medialive") super().__init__(service_name="medialive")
@property @property
def medialive_backend(self): def medialive_backend(self) -> MediaLiveBackend:
return medialive_backends[self.current_account][self.region] return medialive_backends[self.current_account][self.region]
def create_channel(self): def create_channel(self) -> str:
cdi_input_specification = self._get_param("cdiInputSpecification") cdi_input_specification = self._get_param("cdiInputSpecification")
channel_class = self._get_param("channelClass") channel_class = self._get_param("channelClass")
destinations = self._get_param("destinations") destinations = self._get_param("destinations")
@ -39,34 +39,33 @@ class MediaLiveResponse(BaseResponse):
dict(channel=channel.to_dict(exclude=["pipelinesRunningCount"])) dict(channel=channel.to_dict(exclude=["pipelinesRunningCount"]))
) )
def list_channels(self): def list_channels(self) -> str:
max_results = self._get_int_param("maxResults") max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken") channels = self.medialive_backend.list_channels(max_results=max_results)
channels, next_token = self.medialive_backend.list_channels(
max_results=max_results, next_token=next_token
)
return json.dumps(dict(channels=channels, nextToken=next_token)) return json.dumps(dict(channels=channels, nextToken=None))
def describe_channel(self): def describe_channel(self) -> str:
channel_id = self._get_param("channelId") channel_id = self._get_param("channelId")
return json.dumps( channel = self.medialive_backend.describe_channel(channel_id=channel_id)
self.medialive_backend.describe_channel(channel_id=channel_id) return json.dumps(channel.to_dict())
)
def delete_channel(self): def delete_channel(self) -> str:
channel_id = self._get_param("channelId") channel_id = self._get_param("channelId")
return json.dumps(self.medialive_backend.delete_channel(channel_id=channel_id)) channel = self.medialive_backend.delete_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def start_channel(self): def start_channel(self) -> str:
channel_id = self._get_param("channelId") channel_id = self._get_param("channelId")
return json.dumps(self.medialive_backend.start_channel(channel_id=channel_id)) channel = self.medialive_backend.start_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def stop_channel(self): def stop_channel(self) -> str:
channel_id = self._get_param("channelId") channel_id = self._get_param("channelId")
return json.dumps(self.medialive_backend.stop_channel(channel_id=channel_id)) channel = self.medialive_backend.stop_channel(channel_id=channel_id)
return json.dumps(channel.to_dict())
def update_channel(self): def update_channel(self) -> str:
channel_id = self._get_param("channelId") channel_id = self._get_param("channelId")
cdi_input_specification = self._get_param("cdiInputSpecification") cdi_input_specification = self._get_param("cdiInputSpecification")
destinations = self._get_param("destinations") destinations = self._get_param("destinations")
@ -89,7 +88,7 @@ class MediaLiveResponse(BaseResponse):
) )
return json.dumps(dict(channel=channel.to_dict())) return json.dumps(dict(channel=channel.to_dict()))
def create_input(self): def create_input(self) -> str:
destinations = self._get_param("destinations") destinations = self._get_param("destinations")
input_devices = self._get_param("inputDevices") input_devices = self._get_param("inputDevices")
input_security_groups = self._get_param("inputSecurityGroups") input_security_groups = self._get_param("inputSecurityGroups")
@ -112,25 +111,23 @@ class MediaLiveResponse(BaseResponse):
) )
return json.dumps({"input": a_input.to_dict()}) return json.dumps({"input": a_input.to_dict()})
def describe_input(self): def describe_input(self) -> str:
input_id = self._get_param("inputId") input_id = self._get_param("inputId")
return json.dumps(self.medialive_backend.describe_input(input_id=input_id)) a_input = self.medialive_backend.describe_input(input_id=input_id)
return json.dumps(a_input.to_dict())
def list_inputs(self): def list_inputs(self) -> str:
max_results = self._get_int_param("maxResults") max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken") inputs = self.medialive_backend.list_inputs(max_results=max_results)
inputs, next_token = self.medialive_backend.list_inputs(
max_results=max_results, next_token=next_token
)
return json.dumps(dict(inputs=inputs, nextToken=next_token)) return json.dumps(dict(inputs=inputs, nextToken=None))
def delete_input(self): def delete_input(self) -> str:
input_id = self._get_param("inputId") input_id = self._get_param("inputId")
self.medialive_backend.delete_input(input_id=input_id) self.medialive_backend.delete_input(input_id=input_id)
return json.dumps({}) return json.dumps({})
def update_input(self): def update_input(self) -> str:
destinations = self._get_param("destinations") destinations = self._get_param("destinations")
input_devices = self._get_param("inputDevices") input_devices = self._get_param("inputDevices")
input_id = self._get_param("inputId") input_id = self._get_param("inputId")

View File

@ -7,5 +7,5 @@ class MediaPackageClientError(JsonRESTError):
# AWS service exceptions are caught with the underlying botocore exception, ClientError # AWS service exceptions are caught with the underlying botocore exception, ClientError
class ClientError(MediaPackageClientError): class ClientError(MediaPackageClientError):
def __init__(self, error, message): def __init__(self, error: str, message: str):
super().__init__(error, message) super().__init__(error, message)

View File

@ -1,4 +1,5 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
@ -6,27 +7,23 @@ from .exceptions import ClientError
class Channel(BaseModel): class Channel(BaseModel):
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn") self.arn = kwargs.get("arn")
self.channel_id = kwargs.get("channel_id") self.channel_id = kwargs.get("channel_id")
self.description = kwargs.get("description") self.description = kwargs.get("description")
self.tags = kwargs.get("tags") self.tags = kwargs.get("tags")
def to_dict(self, exclude=None): def to_dict(self) -> Dict[str, Any]:
data = { return {
"arn": self.arn, "arn": self.arn,
"id": self.channel_id, "id": self.channel_id,
"description": self.description, "description": self.description,
"tags": self.tags, "tags": self.tags,
} }
if exclude:
for key in exclude:
del data[key]
return data
class OriginEndpoint(BaseModel): class OriginEndpoint(BaseModel):
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn") self.arn = kwargs.get("arn")
self.authorization = kwargs.get("authorization") self.authorization = kwargs.get("authorization")
self.channel_id = kwargs.get("channel_id") self.channel_id = kwargs.get("channel_id")
@ -44,8 +41,8 @@ class OriginEndpoint(BaseModel):
self.url = kwargs.get("url") self.url = kwargs.get("url")
self.whitelist = kwargs.get("whitelist") self.whitelist = kwargs.get("whitelist")
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
data = { return {
"arn": self.arn, "arn": self.arn,
"authorization": self.authorization, "authorization": self.authorization,
"channelId": self.channel_id, "channelId": self.channel_id,
@ -63,69 +60,62 @@ class OriginEndpoint(BaseModel):
"url": self.url, "url": self.url,
"whitelist": self.whitelist, "whitelist": self.whitelist,
} }
return data
class MediaPackageBackend(BaseBackend): class MediaPackageBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self._channels = OrderedDict() self._channels: Dict[str, Channel] = OrderedDict()
self._origin_endpoints = OrderedDict() self._origin_endpoints: Dict[str, OriginEndpoint] = OrderedDict()
def create_channel(self, description, channel_id, tags): def create_channel(
self, description: str, channel_id: str, tags: Dict[str, str]
) -> Channel:
arn = f"arn:aws:mediapackage:channel:{channel_id}" arn = f"arn:aws:mediapackage:channel:{channel_id}"
channel = Channel( channel = Channel(
arn=arn, arn=arn,
description=description, description=description,
egress_access_logs={},
hls_ingest={},
channel_id=channel_id, channel_id=channel_id,
ingress_access_logs={},
tags=tags, tags=tags,
) )
self._channels[channel_id] = channel self._channels[channel_id] = channel
return channel return channel
def list_channels(self): def list_channels(self) -> List[Dict[str, Any]]:
channels = list(self._channels.values()) return [c.to_dict() for c in self._channels.values()]
response_channels = [c.to_dict() for c in channels]
return response_channels
def describe_channel(self, channel_id): def describe_channel(self, channel_id: str) -> Channel:
try: try:
channel = self._channels[channel_id] return self._channels[channel_id]
return channel.to_dict()
except KeyError: except KeyError:
error = "NotFoundException" raise ClientError(
raise ClientError(error, f"channel with id={channel_id} not found") "NotFoundException", f"channel with id={channel_id} not found"
)
def delete_channel(self, channel_id): def delete_channel(self, channel_id: str) -> Channel:
try: if channel_id in self._channels:
channel = self._channels[channel_id] return self._channels.pop(channel_id)
del self._channels[channel_id] raise ClientError(
return channel.to_dict() "NotFoundException", f"channel with id={channel_id} not found"
)
except KeyError:
error = "NotFoundException"
raise ClientError(error, f"channel with id={channel_id} not found")
def create_origin_endpoint( def create_origin_endpoint(
self, self,
authorization, authorization: Dict[str, str],
channel_id, channel_id: str,
cmaf_package, cmaf_package: Dict[str, Any],
dash_package, dash_package: Dict[str, Any],
description, description: str,
hls_package, hls_package: Dict[str, Any],
endpoint_id, endpoint_id: str,
manifest_name, manifest_name: str,
mss_package, mss_package: Dict[str, Any],
origination, origination: str,
startover_window_seconds, startover_window_seconds: int,
tags, tags: Dict[str, str],
time_delay_seconds, time_delay_seconds: int,
whitelist, whitelist: List[str],
): ) -> OriginEndpoint:
arn = f"arn:aws:mediapackage:origin_endpoint:{endpoint_id}" arn = f"arn:aws:mediapackage:origin_endpoint:{endpoint_id}"
url = f"https://origin-endpoint.mediapackage.{self.region_name}.amazonaws.com/{endpoint_id}" url = f"https://origin-endpoint.mediapackage.{self.region_name}.amazonaws.com/{endpoint_id}"
origin_endpoint = OriginEndpoint( origin_endpoint = OriginEndpoint(
@ -149,43 +139,39 @@ class MediaPackageBackend(BaseBackend):
self._origin_endpoints[endpoint_id] = origin_endpoint self._origin_endpoints[endpoint_id] = origin_endpoint
return origin_endpoint return origin_endpoint
def describe_origin_endpoint(self, endpoint_id): def describe_origin_endpoint(self, endpoint_id: str) -> OriginEndpoint:
try: try:
origin_endpoint = self._origin_endpoints[endpoint_id] return self._origin_endpoints[endpoint_id]
return origin_endpoint.to_dict()
except KeyError: except KeyError:
error = "NotFoundException" raise ClientError(
raise ClientError(error, f"origin endpoint with id={endpoint_id} not found") "NotFoundException", f"origin endpoint with id={endpoint_id} not found"
)
def list_origin_endpoints(self): def list_origin_endpoints(self) -> List[Dict[str, Any]]:
origin_endpoints = list(self._origin_endpoints.values()) return [o.to_dict() for o in self._origin_endpoints.values()]
response_origin_endpoints = [o.to_dict() for o in origin_endpoints]
return response_origin_endpoints
def delete_origin_endpoint(self, endpoint_id): def delete_origin_endpoint(self, endpoint_id: str) -> OriginEndpoint:
try: if endpoint_id in self._origin_endpoints:
origin_endpoint = self._origin_endpoints[endpoint_id] return self._origin_endpoints.pop(endpoint_id)
del self._origin_endpoints[endpoint_id] raise ClientError(
return origin_endpoint.to_dict() "NotFoundException", f"origin endpoint with id={endpoint_id} not found"
except KeyError: )
error = "NotFoundException"
raise ClientError(error, f"origin endpoint with id={endpoint_id} not found")
def update_origin_endpoint( def update_origin_endpoint(
self, self,
authorization, authorization: Dict[str, str],
cmaf_package, cmaf_package: Dict[str, Any],
dash_package, dash_package: Dict[str, Any],
description, description: str,
hls_package, hls_package: Dict[str, Any],
endpoint_id, endpoint_id: str,
manifest_name, manifest_name: str,
mss_package, mss_package: Dict[str, Any],
origination, origination: str,
startover_window_seconds, startover_window_seconds: int,
time_delay_seconds, time_delay_seconds: int,
whitelist, whitelist: List[str],
): ) -> OriginEndpoint:
try: try:
origin_endpoint = self._origin_endpoints[endpoint_id] origin_endpoint = self._origin_endpoints[endpoint_id]
origin_endpoint.authorization = authorization origin_endpoint.authorization = authorization
@ -200,10 +186,10 @@ class MediaPackageBackend(BaseBackend):
origin_endpoint.time_delay_seconds = time_delay_seconds origin_endpoint.time_delay_seconds = time_delay_seconds
origin_endpoint.whitelist = whitelist origin_endpoint.whitelist = whitelist
return origin_endpoint return origin_endpoint
except KeyError: except KeyError:
error = "NotFoundException" raise ClientError(
raise ClientError(error, f"origin endpoint with id={endpoint_id} not found") "NotFoundException", f"origin endpoint with id={endpoint_id} not found"
)
mediapackage_backends = BackendDict(MediaPackageBackend, "mediapackage") mediapackage_backends = BackendDict(MediaPackageBackend, "mediapackage")

View File

@ -1,17 +1,17 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import mediapackage_backends from .models import mediapackage_backends, MediaPackageBackend
import json import json
class MediaPackageResponse(BaseResponse): class MediaPackageResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="mediapackage") super().__init__(service_name="mediapackage")
@property @property
def mediapackage_backend(self): def mediapackage_backend(self) -> MediaPackageBackend:
return mediapackage_backends[self.current_account][self.region] return mediapackage_backends[self.current_account][self.region]
def create_channel(self): def create_channel(self) -> str:
description = self._get_param("description") description = self._get_param("description")
channel_id = self._get_param("id") channel_id = self._get_param("id")
tags = self._get_param("tags") tags = self._get_param("tags")
@ -20,23 +20,21 @@ class MediaPackageResponse(BaseResponse):
) )
return json.dumps(channel.to_dict()) return json.dumps(channel.to_dict())
def list_channels(self): def list_channels(self) -> str:
channels = self.mediapackage_backend.list_channels() channels = self.mediapackage_backend.list_channels()
return json.dumps(dict(channels=channels)) return json.dumps(dict(channels=channels))
def describe_channel(self): def describe_channel(self) -> str:
channel_id = self._get_param("id") channel_id = self._get_param("id")
return json.dumps( channel = self.mediapackage_backend.describe_channel(channel_id=channel_id)
self.mediapackage_backend.describe_channel(channel_id=channel_id) return json.dumps(channel.to_dict())
)
def delete_channel(self): def delete_channel(self) -> str:
channel_id = self._get_param("id") channel_id = self._get_param("id")
return json.dumps( channel = self.mediapackage_backend.delete_channel(channel_id=channel_id)
self.mediapackage_backend.delete_channel(channel_id=channel_id) return json.dumps(channel.to_dict())
)
def create_origin_endpoint(self): def create_origin_endpoint(self) -> str:
authorization = self._get_param("authorization") authorization = self._get_param("authorization")
channel_id = self._get_param("channelId") channel_id = self._get_param("channelId")
cmaf_package = self._get_param("cmafPackage") cmaf_package = self._get_param("cmafPackage")
@ -65,27 +63,29 @@ class MediaPackageResponse(BaseResponse):
startover_window_seconds=startover_window_seconds, startover_window_seconds=startover_window_seconds,
tags=tags, tags=tags,
time_delay_seconds=time_delay_seconds, time_delay_seconds=time_delay_seconds,
whitelist=whitelist, whitelist=whitelist, # type: ignore[arg-type]
) )
return json.dumps(origin_endpoint.to_dict()) return json.dumps(origin_endpoint.to_dict())
def list_origin_endpoints(self): def list_origin_endpoints(self) -> str:
origin_endpoints = self.mediapackage_backend.list_origin_endpoints() origin_endpoints = self.mediapackage_backend.list_origin_endpoints()
return json.dumps(dict(originEndpoints=origin_endpoints)) return json.dumps(dict(originEndpoints=origin_endpoints))
def describe_origin_endpoint(self): def describe_origin_endpoint(self) -> str:
endpoint_id = self._get_param("id") endpoint_id = self._get_param("id")
return json.dumps( endpoint = self.mediapackage_backend.describe_origin_endpoint(
self.mediapackage_backend.describe_origin_endpoint(endpoint_id=endpoint_id) endpoint_id=endpoint_id
) )
return json.dumps(endpoint.to_dict())
def delete_origin_endpoint(self): def delete_origin_endpoint(self) -> str:
endpoint_id = self._get_param("id") endpoint_id = self._get_param("id")
return json.dumps( endpoint = self.mediapackage_backend.delete_origin_endpoint(
self.mediapackage_backend.delete_origin_endpoint(endpoint_id=endpoint_id) endpoint_id=endpoint_id
) )
return json.dumps(endpoint.to_dict())
def update_origin_endpoint(self): def update_origin_endpoint(self) -> str:
authorization = self._get_param("authorization") authorization = self._get_param("authorization")
cmaf_package = self._get_param("cmafPackage") cmaf_package = self._get_param("cmafPackage")
dash_package = self._get_param("dashPackage") dash_package = self._get_param("dashPackage")
@ -110,6 +110,6 @@ class MediaPackageResponse(BaseResponse):
origination=origination, origination=origination,
startover_window_seconds=startover_window_seconds, startover_window_seconds=startover_window_seconds,
time_delay_seconds=time_delay_seconds, time_delay_seconds=time_delay_seconds,
whitelist=whitelist, whitelist=whitelist, # type: ignore[arg-type]
) )
return json.dumps(origin_endpoint.to_dict()) return json.dumps(origin_endpoint.to_dict())

View File

@ -1,3 +1,4 @@
from typing import Optional
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
@ -6,7 +7,7 @@ class MediaStoreClientError(JsonRESTError):
class ContainerNotFoundException(MediaStoreClientError): class ContainerNotFoundException(MediaStoreClientError):
def __init__(self, msg=None): def __init__(self, msg: Optional[str] = None):
self.code = 400 self.code = 400
super().__init__( super().__init__(
"ContainerNotFoundException", "ContainerNotFoundException",
@ -15,7 +16,7 @@ class ContainerNotFoundException(MediaStoreClientError):
class ResourceNotFoundException(MediaStoreClientError): class ResourceNotFoundException(MediaStoreClientError):
def __init__(self, msg=None): def __init__(self, msg: Optional[str] = None):
self.code = 400 self.code = 400
super().__init__( super().__init__(
"ResourceNotFoundException", msg or "The specified container does not exist" "ResourceNotFoundException", msg or "The specified container does not exist"
@ -23,7 +24,7 @@ class ResourceNotFoundException(MediaStoreClientError):
class PolicyNotFoundException(MediaStoreClientError): class PolicyNotFoundException(MediaStoreClientError):
def __init__(self, msg=None): def __init__(self, msg: Optional[str] = None):
self.code = 400 self.code = 400
super().__init__( super().__init__(
"PolicyNotFoundException", "PolicyNotFoundException",

View File

@ -1,5 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
from datetime import date from datetime import date
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import ( from .exceptions import (
@ -10,18 +11,18 @@ from .exceptions import (
class Container(BaseModel): class Container(BaseModel):
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
self.arn = kwargs.get("arn") self.arn = kwargs.get("arn")
self.name = kwargs.get("name") self.name = kwargs.get("name")
self.endpoint = kwargs.get("endpoint") self.endpoint = kwargs.get("endpoint")
self.status = kwargs.get("status") self.status = kwargs.get("status")
self.creation_time = kwargs.get("creation_time") self.creation_time = kwargs.get("creation_time")
self.lifecycle_policy = None self.lifecycle_policy: Optional[str] = None
self.policy = None self.policy: Optional[str] = None
self.metric_policy = None self.metric_policy: Optional[str] = None
self.tags = kwargs.get("tags") self.tags = kwargs.get("tags")
def to_dict(self, exclude=None): def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]:
data = { data = {
"ARN": self.arn, "ARN": self.arn,
"Name": self.name, "Name": self.name,
@ -37,11 +38,11 @@ class Container(BaseModel):
class MediaStoreBackend(BaseBackend): class MediaStoreBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self._containers = OrderedDict() self._containers: Dict[str, Container] = OrderedDict()
def create_container(self, name, tags): def create_container(self, name: str, tags: Dict[str, str]) -> Container:
arn = f"arn:aws:mediastore:container:{name}" arn = f"arn:aws:mediastore:container:{name}"
container = Container( container = Container(
arn=arn, arn=arn,
@ -54,40 +55,36 @@ class MediaStoreBackend(BaseBackend):
self._containers[name] = container self._containers[name] = container
return container return container
def delete_container(self, name): def delete_container(self, name: str) -> None:
if name not in self._containers: if name not in self._containers:
raise ContainerNotFoundException() raise ContainerNotFoundException()
del self._containers[name] del self._containers[name]
return {}
def describe_container(self, name): def describe_container(self, name: str) -> Container:
if name not in self._containers: if name not in self._containers:
raise ResourceNotFoundException() raise ResourceNotFoundException()
container = self._containers[name] container = self._containers[name]
container.status = "ACTIVE" container.status = "ACTIVE"
return container return container
def list_containers(self): def list_containers(self) -> List[Dict[str, Any]]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
containers = list(self._containers.values()) return [c.to_dict() for c in self._containers.values()]
response_containers = [c.to_dict() for c in containers]
return response_containers, None
def list_tags_for_resource(self, name): def list_tags_for_resource(self, name: str) -> Optional[Dict[str, str]]:
if name not in self._containers: if name not in self._containers:
raise ContainerNotFoundException() raise ContainerNotFoundException()
tags = self._containers[name].tags tags = self._containers[name].tags
return tags return tags
def put_lifecycle_policy(self, container_name, lifecycle_policy): def put_lifecycle_policy(self, container_name: str, lifecycle_policy: str) -> None:
if container_name not in self._containers: if container_name not in self._containers:
raise ResourceNotFoundException() raise ResourceNotFoundException()
self._containers[container_name].lifecycle_policy = lifecycle_policy self._containers[container_name].lifecycle_policy = lifecycle_policy
return {}
def get_lifecycle_policy(self, container_name): def get_lifecycle_policy(self, container_name: str) -> str:
if container_name not in self._containers: if container_name not in self._containers:
raise ResourceNotFoundException() raise ResourceNotFoundException()
lifecycle_policy = self._containers[container_name].lifecycle_policy lifecycle_policy = self._containers[container_name].lifecycle_policy
@ -95,13 +92,12 @@ class MediaStoreBackend(BaseBackend):
raise PolicyNotFoundException() raise PolicyNotFoundException()
return lifecycle_policy return lifecycle_policy
def put_container_policy(self, container_name, policy): def put_container_policy(self, container_name: str, policy: str) -> None:
if container_name not in self._containers: if container_name not in self._containers:
raise ResourceNotFoundException() raise ResourceNotFoundException()
self._containers[container_name].policy = policy self._containers[container_name].policy = policy
return {}
def get_container_policy(self, container_name): def get_container_policy(self, container_name: str) -> str:
if container_name not in self._containers: if container_name not in self._containers:
raise ResourceNotFoundException() raise ResourceNotFoundException()
policy = self._containers[container_name].policy policy = self._containers[container_name].policy
@ -109,13 +105,12 @@ class MediaStoreBackend(BaseBackend):
raise PolicyNotFoundException() raise PolicyNotFoundException()
return policy return policy
def put_metric_policy(self, container_name, metric_policy): def put_metric_policy(self, container_name: str, metric_policy: str) -> None:
if container_name not in self._containers: if container_name not in self._containers:
raise ResourceNotFoundException() raise ResourceNotFoundException()
self._containers[container_name].metric_policy = metric_policy self._containers[container_name].metric_policy = metric_policy
return {}
def get_metric_policy(self, container_name): def get_metric_policy(self, container_name: str) -> str:
if container_name not in self._containers: if container_name not in self._containers:
raise ResourceNotFoundException() raise ResourceNotFoundException()
metric_policy = self._containers[container_name].metric_policy metric_policy = self._containers[container_name].metric_policy

View File

@ -1,73 +1,73 @@
import json import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import mediastore_backends from .models import mediastore_backends, MediaStoreBackend
class MediaStoreResponse(BaseResponse): class MediaStoreResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="mediastore") super().__init__(service_name="mediastore")
@property @property
def mediastore_backend(self): def mediastore_backend(self) -> MediaStoreBackend:
return mediastore_backends[self.current_account][self.region] return mediastore_backends[self.current_account][self.region]
def create_container(self): def create_container(self) -> str:
name = self._get_param("ContainerName") name = self._get_param("ContainerName")
tags = self._get_param("Tags") tags = self._get_param("Tags")
container = self.mediastore_backend.create_container(name=name, tags=tags) container = self.mediastore_backend.create_container(name=name, tags=tags)
return json.dumps(dict(Container=container.to_dict())) return json.dumps(dict(Container=container.to_dict()))
def delete_container(self): def delete_container(self) -> str:
name = self._get_param("ContainerName") name = self._get_param("ContainerName")
result = self.mediastore_backend.delete_container(name=name) self.mediastore_backend.delete_container(name=name)
return json.dumps(result) return "{}"
def describe_container(self): def describe_container(self) -> str:
name = self._get_param("ContainerName") name = self._get_param("ContainerName")
container = self.mediastore_backend.describe_container(name=name) container = self.mediastore_backend.describe_container(name=name)
return json.dumps(dict(Container=container.to_dict())) return json.dumps(dict(Container=container.to_dict()))
def list_containers(self): def list_containers(self) -> str:
containers, next_token = self.mediastore_backend.list_containers() containers = self.mediastore_backend.list_containers()
return json.dumps(dict(dict(Containers=containers), NextToken=next_token)) return json.dumps(dict(dict(Containers=containers), NextToken=None))
def list_tags_for_resource(self): def list_tags_for_resource(self) -> str:
name = self._get_param("Resource") name = self._get_param("Resource")
tags = self.mediastore_backend.list_tags_for_resource(name) tags = self.mediastore_backend.list_tags_for_resource(name)
return json.dumps(dict(Tags=tags)) return json.dumps(dict(Tags=tags))
def put_lifecycle_policy(self): def put_lifecycle_policy(self) -> str:
container_name = self._get_param("ContainerName") container_name = self._get_param("ContainerName")
lifecycle_policy = self._get_param("LifecyclePolicy") lifecycle_policy = self._get_param("LifecyclePolicy")
policy = self.mediastore_backend.put_lifecycle_policy( self.mediastore_backend.put_lifecycle_policy(
container_name=container_name, lifecycle_policy=lifecycle_policy container_name=container_name, lifecycle_policy=lifecycle_policy
) )
return json.dumps(policy) return "{}"
def get_lifecycle_policy(self): def get_lifecycle_policy(self) -> str:
container_name = self._get_param("ContainerName") container_name = self._get_param("ContainerName")
lifecycle_policy = self.mediastore_backend.get_lifecycle_policy( lifecycle_policy = self.mediastore_backend.get_lifecycle_policy(
container_name=container_name container_name=container_name
) )
return json.dumps(dict(LifecyclePolicy=lifecycle_policy)) return json.dumps(dict(LifecyclePolicy=lifecycle_policy))
def put_container_policy(self): def put_container_policy(self) -> str:
container_name = self._get_param("ContainerName") container_name = self._get_param("ContainerName")
policy = self._get_param("Policy") policy = self._get_param("Policy")
container_policy = self.mediastore_backend.put_container_policy( self.mediastore_backend.put_container_policy(
container_name=container_name, policy=policy container_name=container_name, policy=policy
) )
return json.dumps(container_policy) return "{}"
def get_container_policy(self): def get_container_policy(self) -> str:
container_name = self._get_param("ContainerName") container_name = self._get_param("ContainerName")
policy = self.mediastore_backend.get_container_policy( policy = self.mediastore_backend.get_container_policy(
container_name=container_name container_name=container_name
) )
return json.dumps(dict(Policy=policy)) return json.dumps(dict(Policy=policy))
def put_metric_policy(self): def put_metric_policy(self) -> str:
container_name = self._get_param("ContainerName") container_name = self._get_param("ContainerName")
metric_policy = self._get_param("MetricPolicy") metric_policy = self._get_param("MetricPolicy")
self.mediastore_backend.put_metric_policy( self.mediastore_backend.put_metric_policy(
@ -75,12 +75,9 @@ class MediaStoreResponse(BaseResponse):
) )
return json.dumps(metric_policy) return json.dumps(metric_policy)
def get_metric_policy(self): def get_metric_policy(self) -> str:
container_name = self._get_param("ContainerName") container_name = self._get_param("ContainerName")
metric_policy = self.mediastore_backend.get_metric_policy( metric_policy = self.mediastore_backend.get_metric_policy(
container_name=container_name container_name=container_name
) )
return json.dumps(dict(MetricPolicy=metric_policy)) return json.dumps(dict(MetricPolicy=metric_policy))
# add templates from here

View File

@ -7,5 +7,5 @@ class MediaStoreDataClientError(JsonRESTError):
# AWS service exceptions are caught with the underlying botocore exception, ClientError # AWS service exceptions are caught with the underlying botocore exception, ClientError
class ClientError(MediaStoreDataClientError): class ClientError(MediaStoreDataClientError):
def __init__(self, error, message): def __init__(self, error: str, message: str):
super().__init__(error, message) super().__init__(error, message)

View File

@ -1,20 +1,23 @@
import hashlib import hashlib
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import ClientError from .exceptions import ClientError
class Object(BaseModel): class Object(BaseModel):
def __init__(self, path, body, etag, storage_class="TEMPORAL"): def __init__(
self, path: str, body: str, etag: str, storage_class: str = "TEMPORAL"
):
self.path = path self.path = path
self.body = body self.body = body
self.content_sha256 = hashlib.sha256(body.encode("utf-8")).hexdigest() self.content_sha256 = hashlib.sha256(body.encode("utf-8")).hexdigest()
self.etag = etag self.etag = etag
self.storage_class = storage_class self.storage_class = storage_class
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
data = { return {
"ETag": self.etag, "ETag": self.etag,
"Name": self.path, "Name": self.path,
"Type": "FILE", "Type": "FILE",
@ -24,15 +27,15 @@ class Object(BaseModel):
"ContentSHA256": self.content_sha256, "ContentSHA256": self.content_sha256,
} }
return data
class MediaStoreDataBackend(BaseBackend): class MediaStoreDataBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self._objects = OrderedDict() self._objects: Dict[str, Object] = OrderedDict()
def put_object(self, body, path, storage_class="TEMPORAL"): def put_object(
self, body: str, path: str, storage_class: str = "TEMPORAL"
) -> Object:
""" """
The following parameters are not yet implemented: ContentType, CacheControl, UploadAvailability The following parameters are not yet implemented: ContentType, CacheControl, UploadAvailability
""" """
@ -42,30 +45,29 @@ class MediaStoreDataBackend(BaseBackend):
self._objects[path] = new_object self._objects[path] = new_object
return new_object return new_object
def delete_object(self, path): def delete_object(self, path: str) -> None:
if path not in self._objects: if path not in self._objects:
error = "ObjectNotFoundException" raise ClientError(
raise ClientError(error, f"Object with id={path} not found") "ObjectNotFoundException", f"Object with id={path} not found"
)
del self._objects[path] del self._objects[path]
return {}
def get_object(self, path): def get_object(self, path: str) -> Object:
""" """
The Range-parameter is not yet supported. The Range-parameter is not yet supported.
""" """
objects_found = [item for item in self._objects.values() if item.path == path] objects_found = [item for item in self._objects.values() if item.path == path]
if len(objects_found) == 0: if len(objects_found) == 0:
error = "ObjectNotFoundException" raise ClientError(
raise ClientError(error, f"Object with id={path} not found") "ObjectNotFoundException", f"Object with id={path} not found"
)
return objects_found[0] return objects_found[0]
def list_items(self): def list_items(self) -> List[Dict[str, Any]]:
""" """
The Path- and MaxResults-parameters are not yet supported. The Path- and MaxResults-parameters are not yet supported.
""" """
items = self._objects.values() return [c.to_dict() for c in self._objects.values()]
response_items = [c.to_dict() for c in items]
return response_items
mediastoredata_backends = BackendDict(MediaStoreDataBackend, "mediastore-data") mediastoredata_backends = BackendDict(MediaStoreDataBackend, "mediastore-data")

View File

@ -1,35 +1,36 @@
import json import json
from typing import Dict, Tuple
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import mediastoredata_backends from .models import mediastoredata_backends, MediaStoreDataBackend
class MediaStoreDataResponse(BaseResponse): class MediaStoreDataResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="mediastore-data") super().__init__(service_name="mediastore-data")
@property @property
def mediastoredata_backend(self): def mediastoredata_backend(self) -> MediaStoreDataBackend:
return mediastoredata_backends[self.current_account][self.region] return mediastoredata_backends[self.current_account][self.region]
def get_object(self): def get_object(self) -> Tuple[str, Dict[str, str]]:
path = self._get_param("Path") path = self._get_param("Path")
result = self.mediastoredata_backend.get_object(path=path) result = self.mediastoredata_backend.get_object(path=path)
headers = {"Path": result.path} headers = {"Path": result.path}
return result.body, headers return result.body, headers
def put_object(self): def put_object(self) -> str:
body = self.body body = self.body
path = self._get_param("Path") path = self._get_param("Path")
new_object = self.mediastoredata_backend.put_object(body, path) new_object = self.mediastoredata_backend.put_object(body, path)
object_dict = new_object.to_dict() object_dict = new_object.to_dict()
return json.dumps(object_dict) return json.dumps(object_dict)
def delete_object(self): def delete_object(self) -> str:
item_id = self._get_param("Path") item_id = self._get_param("Path")
result = self.mediastoredata_backend.delete_object(path=item_id) self.mediastoredata_backend.delete_object(path=item_id)
return json.dumps(result) return "{}"
def list_items(self): def list_items(self) -> str:
items = self.mediastoredata_backend.list_items() items = self.mediastoredata_backend.list_items()
return json.dumps(dict(Items=items)) return json.dumps(dict(Items=items))

View File

@ -1,38 +0,0 @@
from moto.core.exceptions import JsonRESTError
class DisabledApiException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="DisabledApiException", message=message)
class InternalServiceErrorException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="InternalServiceErrorException", message=message)
class InvalidCustomerIdentifierException(JsonRESTError):
def __init__(self, message):
super().__init__(
error_type="InvalidCustomerIdentifierException", message=message
)
class InvalidProductCodeException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="InvalidProductCodeException", message=message)
class InvalidUsageDimensionException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="InvalidUsageDimensionException", message=message)
class ThrottlingException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="ThrottlingException", message=message)
class TimestampOutOfBoundsException(JsonRESTError):
def __init__(self, message):
super().__init__(error_type="TimestampOutOfBoundsException", message=message)

View File

@ -1,10 +1,17 @@
import collections import collections
from typing import Any, Deque, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
class UsageRecord(BaseModel, dict): class UsageRecord(BaseModel, Dict[str, Any]): # type: ignore[misc]
def __init__(self, timestamp, customer_identifier, dimension, quantity=0): def __init__(
self,
timestamp: str,
customer_identifier: str,
dimension: str,
quantity: int = 0,
):
super().__init__() super().__init__()
self.timestamp = timestamp self.timestamp = timestamp
self.customer_identifier = customer_identifier self.customer_identifier = customer_identifier
@ -12,54 +19,45 @@ class UsageRecord(BaseModel, dict):
self.quantity = quantity self.quantity = quantity
self.metering_record_id = mock_random.uuid4().hex self.metering_record_id = mock_random.uuid4().hex
@classmethod
def from_data(cls, data):
cls(
timestamp=data.get("Timestamp"),
customer_identifier=data.get("CustomerIdentifier"),
dimension=data.get("Dimension"),
quantity=data.get("Quantity", 0),
)
@property @property
def timestamp(self): def timestamp(self) -> str:
return self["Timestamp"] return self["Timestamp"]
@timestamp.setter @timestamp.setter
def timestamp(self, value): def timestamp(self, value: str) -> None:
self["Timestamp"] = value self["Timestamp"] = value
@property @property
def customer_identifier(self): def customer_identifier(self) -> str:
return self["CustomerIdentifier"] return self["CustomerIdentifier"]
@customer_identifier.setter @customer_identifier.setter
def customer_identifier(self, value): def customer_identifier(self, value: str) -> None:
self["CustomerIdentifier"] = value self["CustomerIdentifier"] = value
@property @property
def dimension(self): def dimension(self) -> str:
return self["Dimension"] return self["Dimension"]
@dimension.setter @dimension.setter
def dimension(self, value): def dimension(self, value: str) -> None:
self["Dimension"] = value self["Dimension"] = value
@property @property
def quantity(self): def quantity(self) -> int:
return self["Quantity"] return self["Quantity"]
@quantity.setter @quantity.setter
def quantity(self, value): def quantity(self, value: int) -> None:
self["Quantity"] = value self["Quantity"] = value
class Result(BaseModel, dict): class Result(BaseModel, Dict[str, Any]): # type: ignore[misc]
SUCCESS = "Success" SUCCESS = "Success"
CUSTOMER_NOT_SUBSCRIBED = "CustomerNotSubscribed" CUSTOMER_NOT_SUBSCRIBED = "CustomerNotSubscribed"
DUPLICATE_RECORD = "DuplicateRecord" DUPLICATE_RECORD = "DuplicateRecord"
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
self.usage_record = UsageRecord( self.usage_record = UsageRecord(
timestamp=kwargs["Timestamp"], timestamp=kwargs["Timestamp"],
customer_identifier=kwargs["CustomerIdentifier"], customer_identifier=kwargs["CustomerIdentifier"],
@ -70,28 +68,26 @@ class Result(BaseModel, dict):
self["MeteringRecordId"] = self.usage_record.metering_record_id self["MeteringRecordId"] = self.usage_record.metering_record_id
@property @property
def metering_record_id(self): def metering_record_id(self) -> str:
return self["MeteringRecordId"] return self["MeteringRecordId"]
@property @property
def status(self): def status(self) -> str:
return self["Status"] return self["Status"]
@status.setter @status.setter
def status(self, value): def status(self, value: str) -> None:
self["Status"] = value self["Status"] = value
@property @property
def usage_record(self): def usage_record(self) -> UsageRecord:
return self["UsageRecord"] return self["UsageRecord"]
@usage_record.setter @usage_record.setter
def usage_record(self, value): def usage_record(self, value: UsageRecord) -> None:
if not isinstance(value, UsageRecord):
value = UsageRecord.from_data(value)
self["UsageRecord"] = value self["UsageRecord"] = value
def is_duplicate(self, other): def is_duplicate(self, other: Any) -> bool:
""" """
DuplicateRecord - Indicates that the UsageRecord was invalid and not honored. DuplicateRecord - Indicates that the UsageRecord was invalid and not honored.
A previously metered UsageRecord had the same customer, dimension, and time, A previously metered UsageRecord had the same customer, dimension, and time,
@ -107,23 +103,29 @@ class Result(BaseModel, dict):
) )
class CustomerDeque(collections.deque): class CustomerDeque(Deque[str]):
def is_subscribed(self, customer): def is_subscribed(self, customer: str) -> bool:
return customer in self return customer in self
class ResultDeque(collections.deque): class ResultDeque(Deque[Result]):
def is_duplicate(self, result): def is_duplicate(self, result: Result) -> bool:
return any(record.is_duplicate(result) for record in self) return any(record.is_duplicate(result) for record in self)
class MeteringMarketplaceBackend(BaseBackend): class MeteringMarketplaceBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.customers_by_product = collections.defaultdict(CustomerDeque) self.customers_by_product: Dict[str, CustomerDeque] = collections.defaultdict(
self.records_by_product = collections.defaultdict(ResultDeque) CustomerDeque
)
self.records_by_product: Dict[str, ResultDeque] = collections.defaultdict(
ResultDeque
)
def batch_meter_usage(self, product_code, usage_records): def batch_meter_usage(
self, product_code: str, usage_records: List[Dict[str, Any]]
) -> List[Result]:
results = [] results = []
for usage in usage_records: for usage in usage_records:
result = Result(**usage) result = Result(**usage)

View File

@ -9,8 +9,7 @@ class MarketplaceMeteringResponse(BaseResponse):
def backend(self) -> MeteringMarketplaceBackend: def backend(self) -> MeteringMarketplaceBackend:
return meteringmarketplace_backends[self.current_account][self.region] return meteringmarketplace_backends[self.current_account][self.region]
def batch_meter_usage(self): def batch_meter_usage(self) -> str:
results = []
usage_records = json.loads(self.body)["UsageRecords"] usage_records = json.loads(self.body)["UsageRecords"]
product_code = json.loads(self.body)["ProductCode"] product_code = json.loads(self.body)["ProductCode"]
results = self.backend.batch_meter_usage(product_code, usage_records) results = self.backend.batch_meter_usage(product_code, usage_records)

View File

@ -1,29 +1,32 @@
import time import time
from threading import Thread from threading import Thread
from werkzeug.serving import make_server from typing import Optional
from werkzeug.serving import make_server, BaseWSGIServer
from .werkzeug_app import DomainDispatcherApplication, create_backend_app from .werkzeug_app import DomainDispatcherApplication, create_backend_app
class ThreadedMotoServer: class ThreadedMotoServer:
def __init__(self, ip_address="0.0.0.0", port=5000, verbose=True): def __init__(
self, ip_address: str = "0.0.0.0", port: int = 5000, verbose: bool = True
):
self._port = port self._port = port
self._thread = None self._thread: Optional[Thread] = None
self._ip_address = ip_address self._ip_address = ip_address
self._server = None self._server: Optional[BaseWSGIServer] = None
self._server_ready = False self._server_ready = False
self._verbose = verbose self._verbose = verbose
def _server_entry(self): def _server_entry(self) -> None:
app = DomainDispatcherApplication(create_backend_app) app = DomainDispatcherApplication(create_backend_app)
self._server = make_server(self._ip_address, self._port, app, True) self._server = make_server(self._ip_address, self._port, app, True)
self._server_ready = True self._server_ready = True
self._server.serve_forever() self._server.serve_forever()
def start(self): def start(self) -> None:
if self._verbose: if self._verbose:
print( # noqa print( # noqa
f"Starting a new Thread with MotoServer running on {self._ip_address}:{self._port}..." f"Starting a new Thread with MotoServer running on {self._ip_address}:{self._port}..."
@ -33,9 +36,9 @@ class ThreadedMotoServer:
while not self._server_ready: while not self._server_ready:
time.sleep(0.1) time.sleep(0.1)
def stop(self): def stop(self) -> None:
self._server_ready = False self._server_ready = False
if self._server: if self._server:
self._server.shutdown() self._server.shutdown()
self._thread.join() self._thread.join() # type: ignore[union-attr]

View File

@ -1,6 +1,6 @@
import json import json
from flask.testing import FlaskClient from flask.testing import FlaskClient
from typing import Any, Dict
from urllib.parse import urlencode from urllib.parse import urlencode
from werkzeug.routing import BaseConverter from werkzeug.routing import BaseConverter
@ -10,13 +10,13 @@ class RegexConverter(BaseConverter):
part_isolating = False part_isolating = False
def __init__(self, url_map, *items): def __init__(self, url_map: Any, *items: Any):
super().__init__(url_map) super().__init__(url_map)
self.regex = items[0] self.regex = items[0]
class AWSTestHelper(FlaskClient): class AWSTestHelper(FlaskClient):
def action_data(self, action_name, **kwargs): def action_data(self, action_name: str, **kwargs: Any) -> str:
""" """
Method calls resource with action_name and returns data of response. Method calls resource with action_name and returns data of response.
""" """
@ -24,11 +24,11 @@ class AWSTestHelper(FlaskClient):
opts.update(kwargs) opts.update(kwargs)
res = self.get( res = self.get(
f"/?{urlencode(opts)}", f"/?{urlencode(opts)}",
headers={"Host": f"{self.application.service}.us-east-1.amazonaws.com"}, headers={"Host": f"{self.application.service}.us-east-1.amazonaws.com"}, # type: ignore[attr-defined]
) )
return res.data.decode("utf-8") return res.data.decode("utf-8")
def action_json(self, action_name, **kwargs): def action_json(self, action_name: str, **kwargs: Any) -> Dict[str, Any]:
""" """
Method calls resource with action_name and returns object obtained via Method calls resource with action_name and returns object obtained via
deserialization of output. deserialization of output.

View File

@ -2,6 +2,7 @@ import io
import os import os
import os.path import os.path
from threading import Lock from threading import Lock
from typing import Any, Callable, Dict, Optional, Tuple
try: try:
from flask import Flask from flask import Flask
@ -49,20 +50,22 @@ SIGNING_ALIASES = {
SERVICE_BY_VERSION = {"2009-04-15": "sdb"} SERVICE_BY_VERSION = {"2009-04-15": "sdb"}
class DomainDispatcherApplication(object): class DomainDispatcherApplication:
""" """
Dispatch requests to different applications based on the "Host:" header Dispatch requests to different applications based on the "Host:" header
value. We'll match the host header value with the url_bases of each backend. value. We'll match the host header value with the url_bases of each backend.
""" """
def __init__(self, create_app, service=None): def __init__(
self, create_app: Callable[[str], Flask], service: Optional[str] = None
):
self.create_app = create_app self.create_app = create_app
self.lock = Lock() self.lock = Lock()
self.app_instances = {} self.app_instances: Dict[str, Flask] = {}
self.service = service self.service = service
self.backend_url_patterns = backend_index.backend_url_patterns self.backend_url_patterns = backend_index.backend_url_patterns
def get_backend_for_host(self, host): def get_backend_for_host(self, host: str) -> Any:
if host == "moto_api": if host == "moto_api":
return host return host
@ -83,7 +86,9 @@ class DomainDispatcherApplication(object):
"Remember to add the URL to urls.py, and run scripts/update_backend_index.py to index it." "Remember to add the URL to urls.py, and run scripts/update_backend_index.py to index it."
) )
def infer_service_region_host(self, body, environ): def infer_service_region_host(
self, body: Optional[str], environ: Dict[str, Any]
) -> str:
auth = environ.get("HTTP_AUTHORIZATION") auth = environ.get("HTTP_AUTHORIZATION")
target = environ.get("HTTP_X_AMZ_TARGET") target = environ.get("HTTP_X_AMZ_TARGET")
service = None service = None
@ -111,7 +116,7 @@ class DomainDispatcherApplication(object):
service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION) service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION)
elif action and action in UNSIGNED_ACTIONS: elif action and action in UNSIGNED_ACTIONS:
# See if we can match the Action to a known service # See if we can match the Action to a known service
service, region = UNSIGNED_ACTIONS.get(action) service, region = UNSIGNED_ACTIONS[action]
if not service: if not service:
service, region = self.get_service_from_body(body, environ) service, region = self.get_service_from_body(body, environ)
if not service: if not service:
@ -153,7 +158,7 @@ class DomainDispatcherApplication(object):
return host return host
def get_application(self, environ): def get_application(self, environ: Dict[str, Any]) -> Flask:
path_info = environ.get("PATH_INFO", "") path_info = environ.get("PATH_INFO", "")
# The URL path might contain non-ASCII text, for instance unicode S3 bucket names # The URL path might contain non-ASCII text, for instance unicode S3 bucket names
@ -181,7 +186,7 @@ class DomainDispatcherApplication(object):
self.app_instances[backend] = app self.app_instances[backend] = app
return app return app
def _get_body(self, environ): def _get_body(self, environ: Dict[str, Any]) -> Optional[str]:
body = None body = None
try: try:
# AWS requests use querystrings as the body (Action=x&Data=y&...) # AWS requests use querystrings as the body (Action=x&Data=y&...)
@ -190,7 +195,7 @@ class DomainDispatcherApplication(object):
) )
request_body_size = int(environ["CONTENT_LENGTH"]) request_body_size = int(environ["CONTENT_LENGTH"])
if simple_form and request_body_size: if simple_form and request_body_size:
body = environ["wsgi.input"].read(request_body_size).decode("utf-8") body = environ["wsgi.input"].read(request_body_size).decode("utf-8") # type: ignore
except (KeyError, ValueError): except (KeyError, ValueError):
pass pass
finally: finally:
@ -199,7 +204,9 @@ class DomainDispatcherApplication(object):
environ["wsgi.input"] = io.StringIO(body) environ["wsgi.input"] = io.StringIO(body)
return body return body
def get_service_from_body(self, body, environ): def get_service_from_body(
self, body: Optional[str], environ: Dict[str, Any]
) -> Tuple[Optional[str], Optional[str]]:
# Some services have the SDK Version in the body # Some services have the SDK Version in the body
# If the version is unique, we can derive the service from it # If the version is unique, we can derive the service from it
version = self.get_version_from_body(body) version = self.get_version_from_body(body)
@ -209,22 +216,24 @@ class DomainDispatcherApplication(object):
return SERVICE_BY_VERSION[version], region return SERVICE_BY_VERSION[version], region
return None, None return None, None
def get_version_from_body(self, body): def get_version_from_body(self, body: Optional[str]) -> Optional[str]:
try: try:
body_dict = dict(x.split("=") for x in body.split("&")) body_dict = dict(x.split("=") for x in body.split("&")) # type: ignore
return body_dict["Version"] return body_dict["Version"]
except (AttributeError, KeyError, ValueError): except (AttributeError, KeyError, ValueError):
return None return None
def get_action_from_body(self, body): def get_action_from_body(self, body: Optional[str]) -> Optional[str]:
try: try:
# AWS requests use querystrings as the body (Action=x&Data=y&...) # AWS requests use querystrings as the body (Action=x&Data=y&...)
body_dict = dict(x.split("=") for x in body.split("&")) body_dict = dict(x.split("=") for x in body.split("&")) # type: ignore
return body_dict["Action"] return body_dict["Action"]
except (AttributeError, KeyError, ValueError): except (AttributeError, KeyError, ValueError):
return None return None
def get_service_from_path(self, environ): def get_service_from_path(
self, environ: Dict[str, Any]
) -> Tuple[Optional[str], Optional[str]]:
# Moto sometimes needs to send a HTTP request to itself # Moto sometimes needs to send a HTTP request to itself
# In which case it will send a request to 'http://localhost/service_region/whatever' # In which case it will send a request to 'http://localhost/service_region/whatever'
try: try:
@ -234,12 +243,12 @@ class DomainDispatcherApplication(object):
except (AttributeError, KeyError, ValueError): except (AttributeError, KeyError, ValueError):
return None, None return None, None
def __call__(self, environ, start_response): def __call__(self, environ: Dict[str, Any], start_response: Any) -> Any:
backend_app = self.get_application(environ) backend_app = self.get_application(environ)
return backend_app(environ, start_response) return backend_app(environ, start_response)
def create_backend_app(service): def create_backend_app(service: str) -> Flask:
from werkzeug.routing import Map from werkzeug.routing import Map
current_file = os.path.abspath(__file__) current_file = os.path.abspath(__file__)
@ -249,7 +258,7 @@ def create_backend_app(service):
# Create the backend_app # Create the backend_app
backend_app = Flask("moto", template_folder=template_dir) backend_app = Flask("moto", template_folder=template_dir)
backend_app.debug = True backend_app.debug = True
backend_app.service = service backend_app.service = service # type: ignore[attr-defined]
CORS(backend_app) CORS(backend_app)
# Reset view functions to reset the app # Reset view functions to reset the app

View File

@ -1,4 +1,5 @@
import json import json
from typing import Any
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
@ -7,11 +8,13 @@ class MQError(JsonRESTError):
class UnknownBroker(MQError): class UnknownBroker(MQError):
def __init__(self, broker_id): def __init__(self, broker_id: str):
super().__init__("NotFoundException", "Can't find requested broker") super().__init__("NotFoundException", "Can't find requested broker")
self.broker_id = broker_id self.broker_id = broker_id
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "broker-id", "errorAttribute": "broker-id",
"message": f"Can't find requested broker [{self.broker_id}]. Make sure your broker exists.", "message": f"Can't find requested broker [{self.broker_id}]. Make sure your broker exists.",
@ -20,11 +23,13 @@ class UnknownBroker(MQError):
class UnknownConfiguration(MQError): class UnknownConfiguration(MQError):
def __init__(self, config_id): def __init__(self, config_id: str):
super().__init__("NotFoundException", "Can't find requested configuration") super().__init__("NotFoundException", "Can't find requested configuration")
self.config_id = config_id self.config_id = config_id
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "configuration_id", "errorAttribute": "configuration_id",
"message": f"Can't find requested configuration [{self.config_id}]. Make sure your configuration exists.", "message": f"Can't find requested configuration [{self.config_id}]. Make sure your configuration exists.",
@ -33,11 +38,13 @@ class UnknownConfiguration(MQError):
class UnknownUser(MQError): class UnknownUser(MQError):
def __init__(self, username): def __init__(self, username: str):
super().__init__("NotFoundException", "Can't find requested user") super().__init__("NotFoundException", "Can't find requested user")
self.username = username self.username = username
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "username", "errorAttribute": "username",
"message": f"Can't find requested user [{self.username}]. Make sure your user exists.", "message": f"Can't find requested user [{self.username}]. Make sure your user exists.",
@ -46,11 +53,13 @@ class UnknownUser(MQError):
class UnsupportedEngineType(MQError): class UnsupportedEngineType(MQError):
def __init__(self, engine_type): def __init__(self, engine_type: str):
super().__init__("BadRequestException", "") super().__init__("BadRequestException", "")
self.engine_type = engine_type self.engine_type = engine_type
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "engineType", "errorAttribute": "engineType",
"message": f"Broker engine type [{self.engine_type}] does not support configuration.", "message": f"Broker engine type [{self.engine_type}] does not support configuration.",
@ -59,11 +68,13 @@ class UnsupportedEngineType(MQError):
class UnknownEngineType(MQError): class UnknownEngineType(MQError):
def __init__(self, engine_type): def __init__(self, engine_type: str):
super().__init__("BadRequestException", "") super().__init__("BadRequestException", "")
self.engine_type = engine_type self.engine_type = engine_type
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument def get_body(
self, *args: Any, **kwargs: Any
) -> str: # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "engineType", "errorAttribute": "engineType",
"message": f"Broker engine type [{self.engine_type}] is invalid. Valid values are: [ACTIVEMQ]", "message": f"Broker engine type [{self.engine_type}] is invalid. Valid values are: [ACTIVEMQ]",

View File

@ -1,5 +1,6 @@
import base64 import base64
import xmltodict import xmltodict
from typing import Any, Dict, List, Iterable, Optional, Tuple
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
@ -17,7 +18,13 @@ from .exceptions import (
class ConfigurationRevision(BaseModel): class ConfigurationRevision(BaseModel):
def __init__(self, configuration_id, revision_id, description, data=None): def __init__(
self,
configuration_id: str,
revision_id: str,
description: str,
data: Optional[str] = None,
):
self.configuration_id = configuration_id self.configuration_id = configuration_id
self.created = unix_time() self.created = unix_time()
self.description = description self.description = description
@ -31,7 +38,7 @@ class ConfigurationRevision(BaseModel):
else: else:
self.data = data self.data = data
def has_ldap_auth(self): def has_ldap_auth(self) -> bool:
try: try:
xml = base64.b64decode(self.data) xml = base64.b64decode(self.data)
dct = xmltodict.parse(xml, dict_constructor=dict) dct = xmltodict.parse(xml, dict_constructor=dict)
@ -45,7 +52,7 @@ class ConfigurationRevision(BaseModel):
# If anything fails, lets assume it's not LDAP # If anything fails, lets assume it's not LDAP
return False return False
def to_json(self, full=True): def to_json(self, full: bool = True) -> Dict[str, Any]:
resp = { resp = {
"created": self.created, "created": self.created,
"description": self.description, "description": self.description,
@ -58,7 +65,14 @@ class ConfigurationRevision(BaseModel):
class Configuration(BaseModel): class Configuration(BaseModel):
def __init__(self, account_id, region, name, engine_type, engine_version): def __init__(
self,
account_id: str,
region: str,
name: str,
engine_type: str,
engine_version: str,
):
self.id = f"c-{mock_random.get_random_hex(6)}" self.id = f"c-{mock_random.get_random_hex(6)}"
self.arn = f"arn:aws:mq:{region}:{account_id}:configuration:{self.id}" self.arn = f"arn:aws:mq:{region}:{account_id}:configuration:{self.id}"
self.created = unix_time() self.created = unix_time()
@ -67,7 +81,7 @@ class Configuration(BaseModel):
self.engine_type = engine_type self.engine_type = engine_type
self.engine_version = engine_version self.engine_version = engine_version
self.revisions = dict() self.revisions: Dict[str, ConfigurationRevision] = dict()
default_desc = ( default_desc = (
f"Auto-generated default for {self.name} on {engine_type} {engine_version}" f"Auto-generated default for {self.name} on {engine_type} {engine_version}"
) )
@ -80,7 +94,7 @@ class Configuration(BaseModel):
"ldap" if latest_revision.has_ldap_auth() else "simple" "ldap" if latest_revision.has_ldap_auth() else "simple"
) )
def update(self, data, description): def update(self, data: str, description: str) -> None:
max_revision_id, _ = sorted(self.revisions.items())[-1] max_revision_id, _ = sorted(self.revisions.items())[-1]
next_revision_id = str(int(max_revision_id) + 1) next_revision_id = str(int(max_revision_id) + 1)
latest_revision = ConfigurationRevision( latest_revision = ConfigurationRevision(
@ -95,10 +109,10 @@ class Configuration(BaseModel):
"ldap" if latest_revision.has_ldap_auth() else "simple" "ldap" if latest_revision.has_ldap_auth() else "simple"
) )
def get_revision(self, revision_id): def get_revision(self, revision_id: str) -> ConfigurationRevision:
return self.revisions[revision_id] return self.revisions[revision_id]
def to_json(self): def to_json(self) -> Dict[str, Any]:
_, latest_revision = sorted(self.revisions.items())[-1] _, latest_revision = sorted(self.revisions.items())[-1]
return { return {
"arn": self.arn, "arn": self.arn,
@ -113,22 +127,30 @@ class Configuration(BaseModel):
class User(BaseModel): class User(BaseModel):
def __init__(self, broker_id, username, console_access=None, groups=None): def __init__(
self,
broker_id: str,
username: str,
console_access: Optional[bool] = None,
groups: Optional[List[str]] = None,
):
self.broker_id = broker_id self.broker_id = broker_id
self.username = username self.username = username
self.console_access = console_access or False self.console_access = console_access or False
self.groups = groups or [] self.groups = groups or []
def update(self, console_access, groups): def update(
self, console_access: Optional[bool], groups: Optional[List[str]]
) -> None:
if console_access is not None: if console_access is not None:
self.console_access = console_access self.console_access = console_access
if groups: if groups:
self.groups = groups self.groups = groups
def summary(self): def summary(self) -> Dict[str, str]:
return {"username": self.username} return {"username": self.username}
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"brokerId": self.broker_id, "brokerId": self.broker_id,
"username": self.username, "username": self.username,
@ -140,25 +162,25 @@ class User(BaseModel):
class Broker(BaseModel): class Broker(BaseModel):
def __init__( def __init__(
self, self,
name, name: str,
account_id, account_id: str,
region, region: str,
authentication_strategy, authentication_strategy: str,
auto_minor_version_upgrade, auto_minor_version_upgrade: bool,
configuration, configuration: Dict[str, Any],
deployment_mode, deployment_mode: str,
encryption_options, encryption_options: Dict[str, Any],
engine_type, engine_type: str,
engine_version, engine_version: str,
host_instance_type, host_instance_type: str,
ldap_server_metadata, ldap_server_metadata: Dict[str, Any],
logs, logs: Dict[str, bool],
maintenance_window_start_time, maintenance_window_start_time: Dict[str, str],
publicly_accessible, publicly_accessible: bool,
security_groups, security_groups: List[str],
storage_type, storage_type: str,
subnet_ids, subnet_ids: List[str],
users, users: List[Dict[str, Any]],
): ):
self.name = name self.name = name
self.id = mock_random.get_random_hex(6) self.id = mock_random.get_random_hex(6)
@ -206,7 +228,7 @@ class Broker(BaseModel):
else: else:
self.subnet_ids = ["default-subnet"] self.subnet_ids = ["default-subnet"]
self.users = dict() self.users: Dict[str, User] = dict()
for user in users: for user in users:
self.create_user( self.create_user(
username=user["username"], username=user["username"],
@ -215,13 +237,13 @@ class Broker(BaseModel):
) )
if self.engine_type.upper() == "RABBITMQ": if self.engine_type.upper() == "RABBITMQ":
self.configurations = None self.configurations: Optional[Dict[str, Any]] = None
else: else:
current_config = configuration or { current_config = configuration or {
"id": f"c-{mock_random.get_random_hex(6)}", "id": f"c-{mock_random.get_random_hex(6)}",
"revision": 1, "revision": 1,
} }
self.configurations = { self.configurations = { # type: ignore[no-redef]
"current": current_config, "current": current_config,
"history": [], "history": [],
} }
@ -256,23 +278,23 @@ class Broker(BaseModel):
def update( def update(
self, self,
authentication_strategy, authentication_strategy: Optional[str],
auto_minor_version_upgrade, auto_minor_version_upgrade: Optional[bool],
configuration, configuration: Optional[Dict[str, Any]],
engine_version, engine_version: Optional[str],
host_instance_type, host_instance_type: Optional[str],
ldap_server_metadata, ldap_server_metadata: Optional[Dict[str, Any]],
logs, logs: Optional[Dict[str, bool]],
maintenance_window_start_time, maintenance_window_start_time: Optional[Dict[str, str]],
security_groups, security_groups: Optional[List[str]],
): ) -> None:
if authentication_strategy: if authentication_strategy:
self.authentication_strategy = authentication_strategy self.authentication_strategy = authentication_strategy
if auto_minor_version_upgrade is not None: if auto_minor_version_upgrade is not None:
self.auto_minor_version_upgrade = auto_minor_version_upgrade self.auto_minor_version_upgrade = auto_minor_version_upgrade
if configuration: if configuration:
self.configurations["history"].append(self.configurations["current"]) self.configurations["history"].append(self.configurations["current"]) # type: ignore[index]
self.configurations["current"] = configuration self.configurations["current"] = configuration # type: ignore[index]
if engine_version: if engine_version:
self.engine_version = engine_version self.engine_version = engine_version
if host_instance_type: if host_instance_type:
@ -286,29 +308,33 @@ class Broker(BaseModel):
if security_groups: if security_groups:
self.security_groups = security_groups self.security_groups = security_groups
def reboot(self): def reboot(self) -> None:
pass pass
def create_user(self, username, console_access, groups): def create_user(
self, username: str, console_access: bool, groups: List[str]
) -> None:
user = User(self.id, username, console_access, groups) user = User(self.id, username, console_access, groups)
self.users[username] = user self.users[username] = user
def update_user(self, username, console_access, groups): def update_user(
self, username: str, console_access: bool, groups: List[str]
) -> None:
user = self.get_user(username) user = self.get_user(username)
user.update(console_access, groups) user.update(console_access, groups)
def get_user(self, username): def get_user(self, username: str) -> User:
if username not in self.users: if username not in self.users:
raise UnknownUser(username) raise UnknownUser(username)
return self.users[username] return self.users[username]
def delete_user(self, username): def delete_user(self, username: str) -> None:
self.users.pop(username, None) self.users.pop(username, None)
def list_users(self): def list_users(self) -> Iterable[User]:
return self.users.values() return self.users.values()
def summary(self): def summary(self) -> Dict[str, Any]:
return { return {
"brokerArn": self.arn, "brokerArn": self.arn,
"brokerId": self.id, "brokerId": self.id,
@ -320,7 +346,7 @@ class Broker(BaseModel):
"hostInstanceType": self.host_instance_type, "hostInstanceType": self.host_instance_type,
} }
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"brokerId": self.id, "brokerId": self.id,
"brokerArn": self.arn, "brokerArn": self.arn,
@ -352,33 +378,33 @@ class MQBackend(BaseBackend):
No EC2 integration exists yet - subnet ID's and security group values are not validated. Default values may not exist. No EC2 integration exists yet - subnet ID's and security group values are not validated. Default values may not exist.
""" """
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.brokers = dict() self.brokers: Dict[str, Broker] = dict()
self.configs = dict() self.configs: Dict[str, Configuration] = dict()
self.tagger = TaggingService() self.tagger = TaggingService()
def create_broker( def create_broker(
self, self,
authentication_strategy, authentication_strategy: str,
auto_minor_version_upgrade, auto_minor_version_upgrade: bool,
broker_name, broker_name: str,
configuration, configuration: Dict[str, Any],
deployment_mode, deployment_mode: str,
encryption_options, encryption_options: Dict[str, Any],
engine_type, engine_type: str,
engine_version, engine_version: str,
host_instance_type, host_instance_type: str,
ldap_server_metadata, ldap_server_metadata: Dict[str, Any],
logs, logs: Dict[str, bool],
maintenance_window_start_time, maintenance_window_start_time: Dict[str, str],
publicly_accessible, publicly_accessible: bool,
security_groups, security_groups: List[str],
storage_type, storage_type: str,
subnet_ids, subnet_ids: List[str],
tags, tags: Dict[str, str],
users, users: List[Dict[str, Any]],
): ) -> Tuple[str, str]:
broker = Broker( broker = Broker(
name=broker_name, name=broker_name,
account_id=self.account_id, account_id=self.account_id,
@ -404,44 +430,50 @@ class MQBackend(BaseBackend):
self.create_tags(broker.arn, tags) self.create_tags(broker.arn, tags)
return broker.arn, broker.id return broker.arn, broker.id
def delete_broker(self, broker_id): def delete_broker(self, broker_id: str) -> None:
del self.brokers[broker_id] del self.brokers[broker_id]
def describe_broker(self, broker_id): def describe_broker(self, broker_id: str) -> Broker:
if broker_id not in self.brokers: if broker_id not in self.brokers:
raise UnknownBroker(broker_id) raise UnknownBroker(broker_id)
return self.brokers[broker_id] return self.brokers[broker_id]
def reboot_broker(self, broker_id): def reboot_broker(self, broker_id: str) -> None:
self.brokers[broker_id].reboot() self.brokers[broker_id].reboot()
def list_brokers(self): def list_brokers(self) -> Iterable[Broker]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
return self.brokers.values() return self.brokers.values()
def create_user(self, broker_id, username, console_access, groups): def create_user(
self, broker_id: str, username: str, console_access: bool, groups: List[str]
) -> None:
broker = self.describe_broker(broker_id) broker = self.describe_broker(broker_id)
broker.create_user(username, console_access, groups) broker.create_user(username, console_access, groups)
def update_user(self, broker_id, console_access, groups, username): def update_user(
self, broker_id: str, console_access: bool, groups: List[str], username: str
) -> None:
broker = self.describe_broker(broker_id) broker = self.describe_broker(broker_id)
broker.update_user(username, console_access, groups) broker.update_user(username, console_access, groups)
def describe_user(self, broker_id, username): def describe_user(self, broker_id: str, username: str) -> User:
broker = self.describe_broker(broker_id) broker = self.describe_broker(broker_id)
return broker.get_user(username) return broker.get_user(username)
def delete_user(self, broker_id, username): def delete_user(self, broker_id: str, username: str) -> None:
broker = self.describe_broker(broker_id) broker = self.describe_broker(broker_id)
broker.delete_user(username) broker.delete_user(username)
def list_users(self, broker_id): def list_users(self, broker_id: str) -> Iterable[User]:
broker = self.describe_broker(broker_id) broker = self.describe_broker(broker_id)
return broker.list_users() return broker.list_users()
def create_configuration(self, name, engine_type, engine_version, tags): def create_configuration(
self, name: str, engine_type: str, engine_version: str, tags: Dict[str, str]
) -> Configuration:
if engine_type.upper() == "RABBITMQ": if engine_type.upper() == "RABBITMQ":
raise UnsupportedEngineType(engine_type) raise UnsupportedEngineType(engine_type)
if engine_type.upper() != "ACTIVEMQ": if engine_type.upper() != "ACTIVEMQ":
@ -459,7 +491,9 @@ class MQBackend(BaseBackend):
) )
return config return config
def update_configuration(self, config_id, data, description): def update_configuration(
self, config_id: str, data: str, description: str
) -> Configuration:
""" """
No validation occurs on the provided XML. The authenticationStrategy may be changed depending on the provided configuration. No validation occurs on the provided XML. The authenticationStrategy may be changed depending on the provided configuration.
""" """
@ -467,47 +501,49 @@ class MQBackend(BaseBackend):
config.update(data, description) config.update(data, description)
return config return config
def describe_configuration(self, config_id): def describe_configuration(self, config_id: str) -> Configuration:
if config_id not in self.configs: if config_id not in self.configs:
raise UnknownConfiguration(config_id) raise UnknownConfiguration(config_id)
return self.configs[config_id] return self.configs[config_id]
def describe_configuration_revision(self, config_id, revision_id): def describe_configuration_revision(
self, config_id: str, revision_id: str
) -> ConfigurationRevision:
config = self.configs[config_id] config = self.configs[config_id]
return config.get_revision(revision_id) return config.get_revision(revision_id)
def list_configurations(self): def list_configurations(self) -> Iterable[Configuration]:
""" """
Pagination has not yet been implemented. Pagination has not yet been implemented.
""" """
return self.configs.values() return self.configs.values()
def create_tags(self, resource_arn, tags): def create_tags(self, resource_arn: str, tags: Dict[str, str]) -> None:
self.tagger.tag_resource( self.tagger.tag_resource(
resource_arn, self.tagger.convert_dict_to_tags_input(tags) resource_arn, self.tagger.convert_dict_to_tags_input(tags)
) )
def list_tags(self, arn): def list_tags(self, arn: str) -> Dict[str, str]:
return self.tagger.get_tag_dict_for_resource(arn) return self.tagger.get_tag_dict_for_resource(arn)
def delete_tags(self, resource_arn, tag_keys): def delete_tags(self, resource_arn: str, tag_keys: List[str]) -> None:
if not isinstance(tag_keys, list): if not isinstance(tag_keys, list):
tag_keys = [tag_keys] tag_keys = [tag_keys]
self.tagger.untag_resource_using_names(resource_arn, tag_keys) self.tagger.untag_resource_using_names(resource_arn, tag_keys)
def update_broker( def update_broker(
self, self,
authentication_strategy, authentication_strategy: str,
auto_minor_version_upgrade, auto_minor_version_upgrade: bool,
broker_id, broker_id: str,
configuration, configuration: Dict[str, Any],
engine_version, engine_version: str,
host_instance_type, host_instance_type: str,
ldap_server_metadata, ldap_server_metadata: Dict[str, Any],
logs, logs: Dict[str, bool],
maintenance_window_start_time, maintenance_window_start_time: Dict[str, str],
security_groups, security_groups: List[str],
): ) -> None:
broker = self.describe_broker(broker_id) broker = self.describe_broker(broker_id)
broker.update( broker.update(
authentication_strategy=authentication_strategy, authentication_strategy=authentication_strategy,

View File

@ -1,23 +1,25 @@
"""Handles incoming mq requests, invokes methods, returns responses.""" """Handles incoming mq requests, invokes methods, returns responses."""
import json import json
from typing import Any
from urllib.parse import unquote from urllib.parse import unquote
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import mq_backends from .models import mq_backends, MQBackend
class MQResponse(BaseResponse): class MQResponse(BaseResponse):
"""Handler for MQ requests and responses.""" """Handler for MQ requests and responses."""
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="mq") super().__init__(service_name="mq")
@property @property
def mq_backend(self): def mq_backend(self) -> MQBackend:
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return mq_backends[self.current_account][self.region] return mq_backends[self.current_account][self.region]
def broker(self, request, full_url, headers): def broker(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.describe_broker() return self.describe_broker()
@ -26,40 +28,40 @@ class MQResponse(BaseResponse):
if request.method == "PUT": if request.method == "PUT":
return self.update_broker() return self.update_broker()
def brokers(self, request, full_url, headers): def brokers(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
return self.create_broker() return self.create_broker()
if request.method == "GET": if request.method == "GET":
return self.list_brokers() return self.list_brokers()
def configuration(self, request, full_url, headers): def configuration(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.describe_configuration() return self.describe_configuration()
if request.method == "PUT": if request.method == "PUT":
return self.update_configuration() return self.update_configuration()
def configurations(self, request, full_url, headers): def configurations(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
return self.create_configuration() return self.create_configuration()
if request.method == "GET": if request.method == "GET":
return self.list_configurations() return self.list_configurations()
def configuration_revision(self, request, full_url, headers): def configuration_revision(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.get_configuration_revision() return self.get_configuration_revision()
def tags(self, request, full_url, headers): def tags(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
return self.create_tags() return self.create_tags()
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_tags() return self.delete_tags()
def user(self, request, full_url, headers): def user(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
return self.create_user() return self.create_user()
@ -70,12 +72,12 @@ class MQResponse(BaseResponse):
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_user() return self.delete_user()
def users(self, request, full_url, headers): def users(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.list_users() return self.list_users()
def create_broker(self): def create_broker(self) -> TYPE_RESPONSE:
params = json.loads(self.body) params = json.loads(self.body)
authentication_strategy = params.get("authenticationStrategy") authentication_strategy = params.get("authenticationStrategy")
auto_minor_version_upgrade = params.get("autoMinorVersionUpgrade") auto_minor_version_upgrade = params.get("autoMinorVersionUpgrade")
@ -119,7 +121,7 @@ class MQResponse(BaseResponse):
resp = {"brokerArn": broker_arn, "brokerId": broker_id} resp = {"brokerArn": broker_arn, "brokerId": broker_id}
return 200, {}, json.dumps(resp) return 200, {}, json.dumps(resp)
def update_broker(self): def update_broker(self) -> TYPE_RESPONSE:
params = json.loads(self.body) params = json.loads(self.body)
broker_id = self.path.split("/")[-1] broker_id = self.path.split("/")[-1]
authentication_strategy = params.get("authenticationStrategy") authentication_strategy = params.get("authenticationStrategy")
@ -145,23 +147,23 @@ class MQResponse(BaseResponse):
) )
return self.describe_broker() return self.describe_broker()
def delete_broker(self): def delete_broker(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-1] broker_id = self.path.split("/")[-1]
self.mq_backend.delete_broker(broker_id=broker_id) self.mq_backend.delete_broker(broker_id=broker_id)
return 200, {}, json.dumps(dict(brokerId=broker_id)) return 200, {}, json.dumps(dict(brokerId=broker_id))
def describe_broker(self): def describe_broker(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-1] broker_id = self.path.split("/")[-1]
broker = self.mq_backend.describe_broker(broker_id=broker_id) broker = self.mq_backend.describe_broker(broker_id=broker_id)
resp = broker.to_json() resp = broker.to_json()
resp["tags"] = self.mq_backend.list_tags(broker.arn) resp["tags"] = self.mq_backend.list_tags(broker.arn)
return 200, {}, json.dumps(resp) return 200, {}, json.dumps(resp)
def list_brokers(self): def list_brokers(self) -> TYPE_RESPONSE:
brokers = self.mq_backend.list_brokers() brokers = self.mq_backend.list_brokers()
return 200, {}, json.dumps(dict(brokerSummaries=[b.summary() for b in brokers])) return 200, {}, json.dumps(dict(brokerSummaries=[b.summary() for b in brokers]))
def create_user(self): def create_user(self) -> TYPE_RESPONSE:
params = json.loads(self.body) params = json.loads(self.body)
broker_id = self.path.split("/")[-3] broker_id = self.path.split("/")[-3]
username = self.path.split("/")[-1] username = self.path.split("/")[-1]
@ -170,7 +172,7 @@ class MQResponse(BaseResponse):
self.mq_backend.create_user(broker_id, username, console_access, groups) self.mq_backend.create_user(broker_id, username, console_access, groups)
return 200, {}, "{}" return 200, {}, "{}"
def update_user(self): def update_user(self) -> TYPE_RESPONSE:
params = json.loads(self.body) params = json.loads(self.body)
broker_id = self.path.split("/")[-3] broker_id = self.path.split("/")[-3]
username = self.path.split("/")[-1] username = self.path.split("/")[-1]
@ -184,19 +186,19 @@ class MQResponse(BaseResponse):
) )
return 200, {}, "{}" return 200, {}, "{}"
def describe_user(self): def describe_user(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-3] broker_id = self.path.split("/")[-3]
username = self.path.split("/")[-1] username = self.path.split("/")[-1]
user = self.mq_backend.describe_user(broker_id, username) user = self.mq_backend.describe_user(broker_id, username)
return 200, {}, json.dumps(user.to_json()) return 200, {}, json.dumps(user.to_json())
def delete_user(self): def delete_user(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-3] broker_id = self.path.split("/")[-3]
username = self.path.split("/")[-1] username = self.path.split("/")[-1]
self.mq_backend.delete_user(broker_id, username) self.mq_backend.delete_user(broker_id, username)
return 200, {}, "{}" return 200, {}, "{}"
def list_users(self): def list_users(self) -> TYPE_RESPONSE:
broker_id = self.path.split("/")[-2] broker_id = self.path.split("/")[-2]
users = self.mq_backend.list_users(broker_id=broker_id) users = self.mq_backend.list_users(broker_id=broker_id)
resp = { resp = {
@ -205,7 +207,7 @@ class MQResponse(BaseResponse):
} }
return 200, {}, json.dumps(resp) return 200, {}, json.dumps(resp)
def create_configuration(self): def create_configuration(self) -> TYPE_RESPONSE:
params = json.loads(self.body) params = json.loads(self.body)
name = params.get("name") name = params.get("name")
engine_type = params.get("engineType") engine_type = params.get("engineType")
@ -217,19 +219,19 @@ class MQResponse(BaseResponse):
) )
return 200, {}, json.dumps(config.to_json()) return 200, {}, json.dumps(config.to_json())
def describe_configuration(self): def describe_configuration(self) -> TYPE_RESPONSE:
config_id = self.path.split("/")[-1] config_id = self.path.split("/")[-1]
config = self.mq_backend.describe_configuration(config_id) config = self.mq_backend.describe_configuration(config_id)
resp = config.to_json() resp = config.to_json()
resp["tags"] = self.mq_backend.list_tags(config.arn) resp["tags"] = self.mq_backend.list_tags(config.arn)
return 200, {}, json.dumps(resp) return 200, {}, json.dumps(resp)
def list_configurations(self): def list_configurations(self) -> TYPE_RESPONSE:
configs = self.mq_backend.list_configurations() configs = self.mq_backend.list_configurations()
resp = {"configurations": [c.to_json() for c in configs]} resp = {"configurations": [c.to_json() for c in configs]}
return 200, {}, json.dumps(resp) return 200, {}, json.dumps(resp)
def update_configuration(self): def update_configuration(self) -> TYPE_RESPONSE:
config_id = self.path.split("/")[-1] config_id = self.path.split("/")[-1]
params = json.loads(self.body) params = json.loads(self.body)
data = params.get("data") data = params.get("data")
@ -237,7 +239,7 @@ class MQResponse(BaseResponse):
config = self.mq_backend.update_configuration(config_id, data, description) config = self.mq_backend.update_configuration(config_id, data, description)
return 200, {}, json.dumps(config.to_json()) return 200, {}, json.dumps(config.to_json())
def get_configuration_revision(self): def get_configuration_revision(self) -> TYPE_RESPONSE:
revision_id = self.path.split("/")[-1] revision_id = self.path.split("/")[-1]
config_id = self.path.split("/")[-3] config_id = self.path.split("/")[-3]
revision = self.mq_backend.describe_configuration_revision( revision = self.mq_backend.describe_configuration_revision(
@ -245,19 +247,19 @@ class MQResponse(BaseResponse):
) )
return 200, {}, json.dumps(revision.to_json()) return 200, {}, json.dumps(revision.to_json())
def create_tags(self): def create_tags(self) -> TYPE_RESPONSE:
resource_arn = unquote(self.path.split("/")[-1]) resource_arn = unquote(self.path.split("/")[-1])
tags = json.loads(self.body).get("tags", {}) tags = json.loads(self.body).get("tags", {})
self.mq_backend.create_tags(resource_arn, tags) self.mq_backend.create_tags(resource_arn, tags)
return 200, {}, "{}" return 200, {}, "{}"
def delete_tags(self): def delete_tags(self) -> TYPE_RESPONSE:
resource_arn = unquote(self.path.split("/")[-1]) resource_arn = unquote(self.path.split("/")[-1])
tag_keys = self._get_param("tagKeys") tag_keys = self._get_param("tagKeys")
self.mq_backend.delete_tags(resource_arn, tag_keys) self.mq_backend.delete_tags(resource_arn, tag_keys)
return 200, {}, "{}" return 200, {}, "{}"
def reboot(self, request, full_url, headers): def reboot(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
broker_id = self.path.split("/")[-2] broker_id = self.path.split("/")[-2]

View File

@ -235,7 +235,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/managedblockchain,moto/moto_api,moto/neptune,moto/opensearch,moto/rdsdata files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/neptune,moto/opensearch,moto/rdsdata
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract

View File

@ -205,6 +205,31 @@ def test_start_stop_flow_succeeds():
describe_response["Flow"]["Status"].should.equal("STANDBY") describe_response["Flow"]["Status"].should.equal("STANDBY")
@mock_mediaconnect
def test_unknown_flow():
client = boto3.client("mediaconnect", region_name=region)
with pytest.raises(ClientError) as exc:
client.describe_flow(FlowArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
with pytest.raises(ClientError) as exc:
client.delete_flow(FlowArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
with pytest.raises(ClientError) as exc:
client.start_flow(FlowArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
with pytest.raises(ClientError) as exc:
client.stop_flow(FlowArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
with pytest.raises(ClientError) as exc:
client.list_tags_for_resource(ResourceArn="unknown")
assert exc.value.response["Error"]["Code"] == "NotFoundException"
@mock_mediaconnect @mock_mediaconnect
def test_tag_resource_succeeds(): def test_tag_resource_succeeds():
client = boto3.client("mediaconnect", region_name=region) client = boto3.client("mediaconnect", region_name=region)