diff --git a/moto/packages/boto/ec2/blockdevicemapping.py b/moto/packages/boto/ec2/blockdevicemapping.py index c88f8053d..50bcc21c3 100644 --- a/moto/packages/boto/ec2/blockdevicemapping.py +++ b/moto/packages/boto/ec2/blockdevicemapping.py @@ -63,7 +63,7 @@ class BlockDeviceType(object): EBSBlockDeviceType = BlockDeviceType -class BlockDeviceMapping(dict): +class BlockDeviceMapping(Dict[Any, Any]): """ Represents a collection of BlockDeviceTypes when creating ec2 instances. diff --git a/moto/packages/boto/ec2/ec2object.py b/moto/packages/boto/ec2/ec2object.py index 0067f59ce..f87f72f46 100644 --- a/moto/packages/boto/ec2/ec2object.py +++ b/moto/packages/boto/ec2/ec2object.py @@ -23,11 +23,12 @@ """ Represents an EC2 Object """ +from typing import Any from moto.packages.boto.ec2.tag import TagSet -class EC2Object(object): - def __init__(self, connection=None): +class EC2Object: + def __init__(self, connection: Any = None): self.connection = connection self.region = None @@ -43,6 +44,6 @@ class TaggedEC2Object(EC2Object): object. """ - def __init__(self, connection=None): + def __init__(self, connection: Any = None): super(TaggedEC2Object, self).__init__(connection) - self.tags = TagSet() + self.tags = TagSet() # type: ignore diff --git a/moto/packages/boto/ec2/image.py b/moto/packages/boto/ec2/image.py index b1fba4197..7f629f220 100644 --- a/moto/packages/boto/ec2/image.py +++ b/moto/packages/boto/ec2/image.py @@ -21,5 +21,5 @@ # IN THE SOFTWARE. -class ProductCodes(list): +class ProductCodes(list): # type: ignore pass diff --git a/moto/packages/boto/ec2/instance.py b/moto/packages/boto/ec2/instance.py index 48880ba7f..571c095e2 100644 --- a/moto/packages/boto/ec2/instance.py +++ b/moto/packages/boto/ec2/instance.py @@ -41,12 +41,12 @@ class InstancePlacement: runs on single-tenant hardware. """ - def __init__(self, zone=None, group_name=None, tenancy=None): + def __init__(self, zone: Any = None, group_name: Any = None, tenancy: Any = None): self.zone = zone self.group_name = group_name self.tenancy = tenancy - def __repr__(self): + def __repr__(self) -> Any: return self.zone @@ -62,12 +62,12 @@ class Reservation(EC2Object): Reservation. """ - def __init__(self, reservation_id) -> None: + def __init__(self, reservation_id: Any) -> None: super().__init__(connection=None) self.id = reservation_id self.owner_id = None - self.groups = [] - self.instances = [] + self.groups: Any = [] + self.instances: Any = [] def __repr__(self) -> str: return "Reservation:%s" % self.id @@ -153,9 +153,9 @@ class Instance(TaggedEC2Object): self.group_name = None self.client_token = None self.eventsSet = None - self.groups = [] + self.groups: Any = [] self.platform = None - self.interfaces = [] + self.interfaces: Any = [] self.hypervisor = None self.virtualization_type = None self.architecture = None @@ -164,15 +164,15 @@ class Instance(TaggedEC2Object): self._placement = InstancePlacement() def __repr__(self) -> str: - return "Instance:%s" % self.id + return "Instance:%s" % self.id # type: ignore @property def state(self) -> str: - return self._state.name + return self._state.name # type: ignore @property def state_code(self) -> str: - return self._state.code + return self._state.code # type: ignore @property def placement(self) -> str: diff --git a/moto/packages/boto/ec2/instancetype.py b/moto/packages/boto/ec2/instancetype.py index a84e4879e..0fccc00fb 100644 --- a/moto/packages/boto/ec2/instancetype.py +++ b/moto/packages/boto/ec2/instancetype.py @@ -19,6 +19,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. +# type: ignore from moto.packages.boto.ec2.ec2object import EC2Object diff --git a/moto/packages/boto/ec2/tag.py b/moto/packages/boto/ec2/tag.py index 9f5c2ef88..fa088895e 100644 --- a/moto/packages/boto/ec2/tag.py +++ b/moto/packages/boto/ec2/tag.py @@ -21,7 +21,7 @@ # IN THE SOFTWARE. -class TagSet(dict): +class TagSet(dict): # type: ignore """ A TagSet is used to collect the tags associated with a particular EC2 resource. Not all resources can be tagged but for those that @@ -29,7 +29,7 @@ class TagSet(dict): :class:`boto.ec2.ec2object.TaggedEC2Object` for more details. """ - def __init__(self, connection=None): + def __init__(self, connection=None): # type: ignore self.connection = connection self._current_key = None self._current_value = None diff --git a/moto/packages/cfnresponse/cfnresponse.py b/moto/packages/cfnresponse/cfnresponse.py index 151bc8a21..4be9e11e3 100644 --- a/moto/packages/cfnresponse/cfnresponse.py +++ b/moto/packages/cfnresponse/cfnresponse.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: MIT-0 from __future__ import print_function +from typing import Any import urllib3 import json @@ -15,14 +16,14 @@ http = urllib3.PoolManager() def send( - event, - context, - responseStatus, - responseData, - physicalResourceId=None, - noEcho=False, - reason=None, -): + event: Any, + context: Any, + responseStatus: Any, + responseData: Any, + physicalResourceId: Any = None, + noEcho: bool = False, + reason: Any = None, +) -> None: responseUrl = event["ResponseURL"] print(responseUrl) @@ -49,7 +50,7 @@ def send( headers = {"content-type": "", "content-length": str(len(json_responseBody))} try: - response = http.request( + response = http.request( # type: ignore "PUT", responseUrl, headers=headers, body=json_responseBody ) print("Status code:", response.status) diff --git a/moto/personalize/exceptions.py b/moto/personalize/exceptions.py index 68259f691..bcb41360d 100644 --- a/moto/personalize/exceptions.py +++ b/moto/personalize/exceptions.py @@ -7,7 +7,7 @@ class PersonalizeException(JsonRESTError): class ResourceNotFoundException(PersonalizeException): - def __init__(self, arn): + def __init__(self, arn: str): super().__init__( "ResourceNotFoundException", f"Resource Arn {arn} does not exist." ) diff --git a/moto/personalize/models.py b/moto/personalize/models.py index 410b685b5..3b66371ce 100644 --- a/moto/personalize/models.py +++ b/moto/personalize/models.py @@ -1,4 +1,4 @@ -"""PersonalizeBackend class with methods for supported APIs.""" +from typing import Any, Dict, Iterable from .exceptions import ResourceNotFoundException from moto.core import BaseBackend, BackendDict, BaseModel @@ -6,15 +6,22 @@ from moto.core.utils import unix_time class Schema(BaseModel): - def __init__(self, account_id, region, name, schema, domain): + def __init__( + self, + account_id: str, + region: str, + name: str, + schema: Dict[str, Any], + domain: str, + ): self.name = name self.schema = schema self.domain = domain self.arn = f"arn:aws:personalize:{region}:{account_id}:schema/{name}" self.created = unix_time() - def to_dict(self, full=True): - d = { + def to_dict(self, full: bool = True) -> Dict[str, Any]: + d: Dict[str, Any] = { "name": self.name, "schemaArn": self.arn, "domain": self.domain, @@ -29,32 +36,32 @@ class Schema(BaseModel): class PersonalizeBackend(BaseBackend): """Implementation of Personalize APIs.""" - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.schemas: [str, Schema] = dict() + self.schemas: Dict[str, Schema] = dict() - def create_schema(self, name, schema, domain): + def create_schema(self, name: str, schema_dict: Dict[str, Any], domain: str) -> str: schema = Schema( region=self.region_name, account_id=self.account_id, name=name, - schema=schema, + schema=schema_dict, domain=domain, ) self.schemas[schema.arn] = schema return schema.arn - def delete_schema(self, schema_arn): + def delete_schema(self, schema_arn: str) -> None: if schema_arn not in self.schemas: raise ResourceNotFoundException(schema_arn) self.schemas.pop(schema_arn, None) - def describe_schema(self, schema_arn): + def describe_schema(self, schema_arn: str) -> Schema: if schema_arn not in self.schemas: raise ResourceNotFoundException(schema_arn) return self.schemas[schema_arn] - def list_schemas(self) -> [Schema]: + def list_schemas(self) -> Iterable[Schema]: """ Pagination is not yet implemented """ diff --git a/moto/personalize/responses.py b/moto/personalize/responses.py index ac100b799..f42913ee4 100644 --- a/moto/personalize/responses.py +++ b/moto/personalize/responses.py @@ -2,47 +2,45 @@ import json from moto.core.responses import BaseResponse -from .models import personalize_backends +from .models import personalize_backends, PersonalizeBackend class PersonalizeResponse(BaseResponse): """Handler for Personalize requests and responses.""" - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="personalize") @property - def personalize_backend(self): + def personalize_backend(self) -> PersonalizeBackend: """Return backend instance specific for this region.""" return personalize_backends[self.current_account][self.region] - # add methods from here - - def create_schema(self): + def create_schema(self) -> str: params = json.loads(self.body) name = params.get("name") schema = params.get("schema") domain = params.get("domain") schema_arn = self.personalize_backend.create_schema( name=name, - schema=schema, + schema_dict=schema, domain=domain, ) return json.dumps(dict(schemaArn=schema_arn)) - def delete_schema(self): + def delete_schema(self) -> str: params = json.loads(self.body) schema_arn = params.get("schemaArn") self.personalize_backend.delete_schema(schema_arn=schema_arn) return "{}" - def describe_schema(self): + def describe_schema(self) -> str: params = json.loads(self.body) schema_arn = params.get("schemaArn") schema = self.personalize_backend.describe_schema(schema_arn=schema_arn) return json.dumps(dict(schema=schema.to_dict())) - def list_schemas(self): + def list_schemas(self) -> str: schemas = self.personalize_backend.list_schemas() resp = {"schemas": [s.to_dict(full=False) for s in schemas]} return json.dumps(resp) diff --git a/moto/pinpoint/exceptions.py b/moto/pinpoint/exceptions.py index 4deebe359..817fab862 100644 --- a/moto/pinpoint/exceptions.py +++ b/moto/pinpoint/exceptions.py @@ -9,12 +9,12 @@ class PinpointExceptions(JsonRESTError): class ApplicationNotFound(PinpointExceptions): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("NotFoundException", "Application not found") class EventStreamNotFound(PinpointExceptions): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("NotFoundException", "Resource not found") diff --git a/moto/pinpoint/models.py b/moto/pinpoint/models.py index 5cbdde16f..b3d33669f 100644 --- a/moto/pinpoint/models.py +++ b/moto/pinpoint/models.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any, Dict, List, Iterable, Optional from moto.core import BaseBackend, BackendDict, BaseModel from moto.core.utils import unix_time from moto.moto_api._internal import mock_random @@ -8,7 +9,7 @@ from .exceptions import ApplicationNotFound, EventStreamNotFound class App(BaseModel): - def __init__(self, account_id, name): + def __init__(self, account_id: str, name: str): self.application_id = str(mock_random.uuid4()).replace("-", "") self.arn = ( f"arn:aws:mobiletargeting:us-east-1:{account_id}:apps/{self.application_id}" @@ -16,30 +17,30 @@ class App(BaseModel): self.name = name self.created = unix_time() self.settings = AppSettings() - self.event_stream = None + self.event_stream: Optional[EventStream] = None - def get_settings(self): + def get_settings(self) -> "AppSettings": return self.settings - def update_settings(self, settings): + def update_settings(self, settings: Dict[str, Any]) -> "AppSettings": self.settings.update(settings) return self.settings - def delete_event_stream(self): + def delete_event_stream(self) -> "EventStream": stream = self.event_stream self.event_stream = None - return stream + return stream # type: ignore - def get_event_stream(self): + def get_event_stream(self) -> "EventStream": if self.event_stream is None: raise EventStreamNotFound() return self.event_stream - def put_event_stream(self, stream_arn, role_arn): + def put_event_stream(self, stream_arn: str, role_arn: str) -> "EventStream": self.event_stream = EventStream(stream_arn, role_arn) return self.event_stream - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "Arn": self.arn, "Id": self.application_id, @@ -49,15 +50,15 @@ class App(BaseModel): class AppSettings(BaseModel): - def __init__(self): - self.settings = dict() - self.last_modified = unix_time() + def __init__(self) -> None: + self.settings: Dict[str, Any] = dict() + self.last_modified = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ") - def update(self, settings): + def update(self, settings: Dict[str, Any]) -> None: self.settings = settings self.last_modified = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ") - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "CampaignHook": self.settings.get("CampaignHook", {}), "CloudWatchMetricsEnabled": self.settings.get( @@ -70,12 +71,12 @@ class AppSettings(BaseModel): class EventStream(BaseModel): - def __init__(self, stream_arn, role_arn): + def __init__(self, stream_arn: str, role_arn: str): self.stream_arn = stream_arn self.role_arn = role_arn self.last_modified = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ") - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "DestinationStreamArn": self.stream_arn, "RoleArn": self.role_arn, @@ -86,62 +87,65 @@ class EventStream(BaseModel): class PinpointBackend(BaseBackend): """Implementation of Pinpoint APIs.""" - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.apps = {} + self.apps: Dict[str, App] = {} self.tagger = TaggingService() - def create_app(self, name, tags): + def create_app(self, name: str, tags: Dict[str, str]) -> App: app = App(self.account_id, name) self.apps[app.application_id] = app - tags = self.tagger.convert_dict_to_tags_input(tags) - self.tagger.tag_resource(app.arn, tags) + tag_list = self.tagger.convert_dict_to_tags_input(tags) + self.tagger.tag_resource(app.arn, tag_list) return app - def delete_app(self, application_id): + def delete_app(self, application_id: str) -> App: self.get_app(application_id) return self.apps.pop(application_id) - def get_app(self, application_id): + def get_app(self, application_id: str) -> App: if application_id not in self.apps: raise ApplicationNotFound() return self.apps[application_id] - def get_apps(self): + def get_apps(self) -> Iterable[App]: """ Pagination is not yet implemented """ return self.apps.values() - def update_application_settings(self, application_id, settings): + def update_application_settings( + self, application_id: str, settings: Dict[str, Any] + ) -> AppSettings: app = self.get_app(application_id) return app.update_settings(settings) - def get_application_settings(self, application_id): + def get_application_settings(self, application_id: str) -> AppSettings: app = self.get_app(application_id) return app.get_settings() - def list_tags_for_resource(self, resource_arn): + def list_tags_for_resource(self, resource_arn: str) -> Dict[str, Dict[str, str]]: tags = self.tagger.get_tag_dict_for_resource(resource_arn) return {"tags": tags} - def tag_resource(self, resource_arn, tags): - tags = TaggingService.convert_dict_to_tags_input(tags) - self.tagger.tag_resource(resource_arn, tags) + def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None: + tag_list = TaggingService.convert_dict_to_tags_input(tags) + self.tagger.tag_resource(resource_arn, tag_list) - def untag_resource(self, resource_arn, tag_keys): + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: self.tagger.untag_resource_using_names(resource_arn, tag_keys) - return - def put_event_stream(self, application_id, stream_arn, role_arn): + def put_event_stream( + self, application_id: str, stream_arn: str, role_arn: str + ) -> EventStream: app = self.get_app(application_id) return app.put_event_stream(stream_arn, role_arn) - def get_event_stream(self, application_id): + def get_event_stream(self, application_id: str) -> EventStream: app = self.get_app(application_id) return app.get_event_stream() - def delete_event_stream(self, application_id): + def delete_event_stream(self, application_id: str) -> EventStream: app = self.get_app(application_id) return app.delete_event_stream() diff --git a/moto/pinpoint/responses.py b/moto/pinpoint/responses.py index 54d4b01e8..5b9f9b861 100644 --- a/moto/pinpoint/responses.py +++ b/moto/pinpoint/responses.py @@ -1,44 +1,46 @@ """Handles incoming pinpoint requests, invokes methods, returns responses.""" import json +from typing import Any +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse from urllib.parse import unquote -from .models import pinpoint_backends +from .models import pinpoint_backends, PinpointBackend class PinpointResponse(BaseResponse): """Handler for Pinpoint requests and responses.""" - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="pinpoint") @property - def pinpoint_backend(self): + def pinpoint_backend(self) -> PinpointBackend: """Return backend instance specific for this region.""" return pinpoint_backends[self.current_account][self.region] - def app(self, request, full_url, headers): + def app(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "DELETE": return self.delete_app() if request.method == "GET": return self.get_app() - def apps(self, request, full_url, headers): + def apps(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.get_apps() if request.method == "POST": return self.create_app() - def app_settings(self, request, full_url, headers): + def app_settings(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.get_application_settings() if request.method == "PUT": return self.update_application_settings() - def eventstream(self, request, full_url, headers): + def eventstream(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "DELETE": return self.delete_event_stream() @@ -47,7 +49,7 @@ class PinpointResponse(BaseResponse): if request.method == "POST": return self.put_event_stream() - def tags(self, request, full_url, headers): + def tags(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "DELETE": return self.untag_resource() @@ -56,67 +58,67 @@ class PinpointResponse(BaseResponse): if request.method == "POST": return self.tag_resource() - def create_app(self): + def create_app(self) -> TYPE_RESPONSE: params = json.loads(self.body) name = params.get("Name") tags = params.get("tags", {}) app = self.pinpoint_backend.create_app(name=name, tags=tags) return 201, {}, json.dumps(app.to_json()) - def delete_app(self): + def delete_app(self) -> TYPE_RESPONSE: application_id = self.path.split("/")[-1] app = self.pinpoint_backend.delete_app(application_id=application_id) return 200, {}, json.dumps(app.to_json()) - def get_app(self): + def get_app(self) -> TYPE_RESPONSE: application_id = self.path.split("/")[-1] app = self.pinpoint_backend.get_app(application_id=application_id) return 200, {}, json.dumps(app.to_json()) - def get_apps(self): + def get_apps(self) -> TYPE_RESPONSE: apps = self.pinpoint_backend.get_apps() resp = {"Item": [a.to_json() for a in apps]} return 200, {}, json.dumps(resp) - def update_application_settings(self): + def update_application_settings(self) -> TYPE_RESPONSE: application_id = self.path.split("/")[-2] settings = json.loads(self.body) app_settings = self.pinpoint_backend.update_application_settings( application_id=application_id, settings=settings ) - app_settings = app_settings.to_json() - app_settings["ApplicationId"] = application_id - return 200, {}, json.dumps(app_settings) + response = app_settings.to_json() + response["ApplicationId"] = application_id + return 200, {}, json.dumps(response) - def get_application_settings(self): + def get_application_settings(self) -> TYPE_RESPONSE: application_id = self.path.split("/")[-2] app_settings = self.pinpoint_backend.get_application_settings( application_id=application_id ) - app_settings = app_settings.to_json() - app_settings["ApplicationId"] = application_id - return 200, {}, json.dumps(app_settings) + response = app_settings.to_json() + response["ApplicationId"] = application_id + return 200, {}, json.dumps(response) - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> TYPE_RESPONSE: resource_arn = unquote(self.path).split("/tags/")[-1] tags = self.pinpoint_backend.list_tags_for_resource(resource_arn=resource_arn) return 200, {}, json.dumps(tags) - def tag_resource(self): + def tag_resource(self) -> TYPE_RESPONSE: resource_arn = unquote(self.path).split("/tags/")[-1] tags = json.loads(self.body).get("tags", {}) self.pinpoint_backend.tag_resource(resource_arn=resource_arn, tags=tags) return 200, {}, "{}" - def untag_resource(self): + def untag_resource(self) -> TYPE_RESPONSE: resource_arn = unquote(self.path).split("/tags/")[-1] tag_keys = self.querystring.get("tagKeys") self.pinpoint_backend.untag_resource( - resource_arn=resource_arn, tag_keys=tag_keys + resource_arn=resource_arn, tag_keys=tag_keys # type: ignore[arg-type] ) return 200, {}, "{}" - def put_event_stream(self): + def put_event_stream(self) -> TYPE_RESPONSE: application_id = self.path.split("/")[-2] params = json.loads(self.body) stream_arn = params.get("DestinationStreamArn") @@ -128,7 +130,7 @@ class PinpointResponse(BaseResponse): resp["ApplicationId"] = application_id return 200, {}, json.dumps(resp) - def get_event_stream(self): + def get_event_stream(self) -> TYPE_RESPONSE: application_id = self.path.split("/")[-2] event_stream = self.pinpoint_backend.get_event_stream( application_id=application_id @@ -137,7 +139,7 @@ class PinpointResponse(BaseResponse): resp["ApplicationId"] = application_id return 200, {}, json.dumps(resp) - def delete_event_stream(self): + def delete_event_stream(self) -> TYPE_RESPONSE: application_id = self.path.split("/")[-2] event_stream = self.pinpoint_backend.delete_event_stream( application_id=application_id diff --git a/moto/polly/models.py b/moto/polly/models.py index d28abd409..0bf5b75b2 100644 --- a/moto/polly/models.py +++ b/moto/polly/models.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, List, Optional from xml.etree import ElementTree as ET import datetime @@ -8,7 +9,7 @@ from .utils import make_arn_for_lexicon class Lexicon(BaseModel): - def __init__(self, name, content, account_id, region_name): + def __init__(self, name: str, content: str, account_id: str, region_name: str): self.name = name self.content = content self.size = 0 @@ -20,7 +21,7 @@ class Lexicon(BaseModel): self.update() - def update(self, content=None): + def update(self, content: Optional[str] = None) -> None: if content is not None: self.content = content @@ -28,7 +29,7 @@ class Lexicon(BaseModel): try: root = ET.fromstring(self.content) self.size = len(self.content) - self.last_modified = int( + self.last_modified = int( # type: ignore ( datetime.datetime.now() - datetime.datetime(1970, 1, 1) ).total_seconds() @@ -37,14 +38,14 @@ class Lexicon(BaseModel): for key, value in root.attrib.items(): if key.endswith("alphabet"): - self.alphabet = value + self.alphabet = value # type: ignore elif key.endswith("lang"): - self.language_code = value + self.language_code = value # type: ignore except Exception as err: raise ValueError(f"Failure parsing XML: {err}") - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "Attributes": { "Alphabet": self.alphabet, @@ -56,16 +57,16 @@ class Lexicon(BaseModel): } } - def __repr__(self): + def __repr__(self) -> str: return f"" class PollyBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._lexicons = {} + self._lexicons: Dict[str, Lexicon] = {} - def describe_voices(self, language_code): + def describe_voices(self, language_code: str) -> List[Dict[str, Any]]: """ Pagination is not yet implemented """ @@ -74,15 +75,15 @@ class PollyBackend(BaseBackend): return [item for item in VOICE_DATA if item["LanguageCode"] == language_code] - def delete_lexicon(self, name): + def delete_lexicon(self, name: str) -> None: # implement here del self._lexicons[name] - def get_lexicon(self, name): + def get_lexicon(self, name: str) -> Lexicon: # Raises KeyError return self._lexicons[name] - def list_lexicons(self): + def list_lexicons(self) -> List[Dict[str, Any]]: """ Pagination is not yet implemented """ @@ -97,12 +98,12 @@ class PollyBackend(BaseBackend): return result - def put_lexicon(self, name, content): + def put_lexicon(self, name: str, content: str) -> None: # If lexicon content is bad, it will raise ValueError if name in self._lexicons: # Regenerated all the stats from the XML # but keeps the ARN - self._lexicons.update(content) + self._lexicons[name].update(content) else: lexicon = Lexicon( name, content, self.account_id, region_name=self.region_name diff --git a/moto/polly/responses.py b/moto/polly/responses.py index 6b4532ac3..0a73a3af0 100644 --- a/moto/polly/responses.py +++ b/moto/polly/responses.py @@ -1,33 +1,33 @@ import json import re - +from typing import Any, Dict, Tuple, Union from urllib.parse import urlsplit from moto.core.responses import BaseResponse -from .models import polly_backends +from .models import polly_backends, PollyBackend from .resources import LANGUAGE_CODES, VOICE_IDS LEXICON_NAME_REGEX = re.compile(r"^[0-9A-Za-z]{1,20}$") class PollyResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="polly") @property - def polly_backend(self): + def polly_backend(self) -> PollyBackend: return polly_backends[self.current_account][self.region] @property - def json(self): + def json(self) -> Dict[str, Any]: # type: ignore[misc] if not hasattr(self, "_json"): self._json = json.loads(self.body) return self._json - def _error(self, code, message): + def _error(self, code: str, message: str) -> Tuple[str, Dict[str, int]]: return json.dumps({"__type": code, "message": message}), dict(status=400) - def _get_action(self): + def _get_action(self) -> str: # Amazon is now naming things /v1/api_name url_parts = urlsplit(self.uri).path.lstrip("/").split("/") # [0] = 'v1' @@ -35,11 +35,11 @@ class PollyResponse(BaseResponse): return url_parts[1] # DescribeVoices - def voices(self): + def voices(self) -> Union[str, Tuple[str, Dict[str, int]]]: language_code = self._get_param("LanguageCode") if language_code is not None and language_code not in LANGUAGE_CODES: - all_codes = ", ".join(LANGUAGE_CODES) + all_codes = ", ".join(LANGUAGE_CODES) # type: ignore msg = ( f"1 validation error detected: Value '{language_code}' at 'languageCode' failed to satisfy constraint: " f"Member must satisfy enum value set: [{all_codes}]" @@ -50,7 +50,7 @@ class PollyResponse(BaseResponse): return json.dumps({"Voices": voices}) - def lexicons(self): + def lexicons(self) -> Union[str, Tuple[str, Dict[str, int]]]: # Dish out requests based on methods # anything after the /v1/lexicons/ @@ -69,7 +69,9 @@ class PollyResponse(BaseResponse): return self._error("InvalidAction", "Bad route") # PutLexicon - def _put_lexicons(self, lexicon_name): + def _put_lexicons( + self, lexicon_name: str + ) -> Union[str, Tuple[str, Dict[str, int]]]: if LEXICON_NAME_REGEX.match(lexicon_name) is None: return self._error( "InvalidParameterValue", "Lexicon name must match [0-9A-Za-z]{1,20}" @@ -83,13 +85,13 @@ class PollyResponse(BaseResponse): return "" # ListLexicons - def _get_lexicons_list(self): + def _get_lexicons_list(self) -> str: result = {"Lexicons": self.polly_backend.list_lexicons()} return json.dumps(result) # GetLexicon - def _get_lexicon(self, lexicon_name): + def _get_lexicon(self, lexicon_name: str) -> Union[str, Tuple[str, Dict[str, int]]]: try: lexicon = self.polly_backend.get_lexicon(lexicon_name) except KeyError: @@ -103,7 +105,9 @@ class PollyResponse(BaseResponse): return json.dumps(result) # DeleteLexicon - def _delete_lexicon(self, lexicon_name): + def _delete_lexicon( + self, lexicon_name: str + ) -> Union[str, Tuple[str, Dict[str, int]]]: try: self.polly_backend.delete_lexicon(lexicon_name) except KeyError: @@ -112,7 +116,7 @@ class PollyResponse(BaseResponse): return "" # SynthesizeSpeech - def speech(self): + def speech(self) -> Tuple[str, Dict[str, Any]]: # Sanity check params args = { "lexicon_names": None, @@ -169,12 +173,12 @@ class PollyResponse(BaseResponse): if "VoiceId" not in self.json: return self._error("MissingParameter", "Missing parameter VoiceId") if self.json["VoiceId"] not in VOICE_IDS: - all_voices = ", ".join(VOICE_IDS) + all_voices = ", ".join(VOICE_IDS) # type: ignore return self._error("InvalidParameterValue", f"Not one of {all_voices}") args["voice_id"] = self.json["VoiceId"] # More validation - if len(args["text"]) > 3000: + if len(args["text"]) > 3000: # type: ignore return self._error("TextLengthExceededException", "Text too long") if args["speech_marks"] is not None and args["output_format"] != "json": diff --git a/moto/polly/utils.py b/moto/polly/utils.py index a0de8713a..e946674b2 100644 --- a/moto/polly/utils.py +++ b/moto/polly/utils.py @@ -1,2 +1,2 @@ -def make_arn_for_lexicon(account_id, name, region_name): +def make_arn_for_lexicon(account_id: str, name: str, region_name: str) -> str: return f"arn:aws:polly:{region_name}:{account_id}:lexicon/{name}" diff --git a/setup.cfg b/setup.cfg index 679b85634..a362aea2a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -235,7 +235,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/rdsdata +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/rdsdata show_column_numbers=True show_error_codes = True disable_error_code=abstract diff --git a/tests/test_polly/test_polly.py b/tests/test_polly/test_polly.py index 0a48d1aba..1700dfeb9 100644 --- a/tests/test_polly/test_polly.py +++ b/tests/test_polly/test_polly.py @@ -258,3 +258,10 @@ def test_synthesize_speech_bad_speech_marks2(): ) else: raise RuntimeError("Should have raised ") + + +@mock_polly +def test_update_lexicon(): + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) + client.put_lexicon(Name="test", Content=LEXICON_XML)