Techdebt: MyPy P (#6189)

This commit is contained in:
Bert Blommers 2023-04-08 20:44:26 +00:00 committed by GitHub
parent ba05cc9c81
commit 1eb3479d08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 173 additions and 147 deletions

View File

@ -63,7 +63,7 @@ class BlockDeviceType(object):
EBSBlockDeviceType = BlockDeviceType EBSBlockDeviceType = BlockDeviceType
class BlockDeviceMapping(dict): class BlockDeviceMapping(Dict[Any, Any]):
""" """
Represents a collection of BlockDeviceTypes when creating ec2 instances. Represents a collection of BlockDeviceTypes when creating ec2 instances.

View File

@ -23,11 +23,12 @@
""" """
Represents an EC2 Object Represents an EC2 Object
""" """
from typing import Any
from moto.packages.boto.ec2.tag import TagSet from moto.packages.boto.ec2.tag import TagSet
class EC2Object(object): class EC2Object:
def __init__(self, connection=None): def __init__(self, connection: Any = None):
self.connection = connection self.connection = connection
self.region = None self.region = None
@ -43,6 +44,6 @@ class TaggedEC2Object(EC2Object):
object. object.
""" """
def __init__(self, connection=None): def __init__(self, connection: Any = None):
super(TaggedEC2Object, self).__init__(connection) super(TaggedEC2Object, self).__init__(connection)
self.tags = TagSet() self.tags = TagSet() # type: ignore

View File

@ -21,5 +21,5 @@
# IN THE SOFTWARE. # IN THE SOFTWARE.
class ProductCodes(list): class ProductCodes(list): # type: ignore
pass pass

View File

@ -41,12 +41,12 @@ class InstancePlacement:
runs on single-tenant hardware. runs on single-tenant hardware.
""" """
def __init__(self, zone=None, group_name=None, tenancy=None): def __init__(self, zone: Any = None, group_name: Any = None, tenancy: Any = None):
self.zone = zone self.zone = zone
self.group_name = group_name self.group_name = group_name
self.tenancy = tenancy self.tenancy = tenancy
def __repr__(self): def __repr__(self) -> Any:
return self.zone return self.zone
@ -62,12 +62,12 @@ class Reservation(EC2Object):
Reservation. Reservation.
""" """
def __init__(self, reservation_id) -> None: def __init__(self, reservation_id: Any) -> None:
super().__init__(connection=None) super().__init__(connection=None)
self.id = reservation_id self.id = reservation_id
self.owner_id = None self.owner_id = None
self.groups = [] self.groups: Any = []
self.instances = [] self.instances: Any = []
def __repr__(self) -> str: def __repr__(self) -> str:
return "Reservation:%s" % self.id return "Reservation:%s" % self.id
@ -153,9 +153,9 @@ class Instance(TaggedEC2Object):
self.group_name = None self.group_name = None
self.client_token = None self.client_token = None
self.eventsSet = None self.eventsSet = None
self.groups = [] self.groups: Any = []
self.platform = None self.platform = None
self.interfaces = [] self.interfaces: Any = []
self.hypervisor = None self.hypervisor = None
self.virtualization_type = None self.virtualization_type = None
self.architecture = None self.architecture = None
@ -164,15 +164,15 @@ class Instance(TaggedEC2Object):
self._placement = InstancePlacement() self._placement = InstancePlacement()
def __repr__(self) -> str: def __repr__(self) -> str:
return "Instance:%s" % self.id return "Instance:%s" % self.id # type: ignore
@property @property
def state(self) -> str: def state(self) -> str:
return self._state.name return self._state.name # type: ignore
@property @property
def state_code(self) -> str: def state_code(self) -> str:
return self._state.code return self._state.code # type: ignore
@property @property
def placement(self) -> str: def placement(self) -> str:

View File

@ -19,6 +19,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE. # IN THE SOFTWARE.
# type: ignore
from moto.packages.boto.ec2.ec2object import EC2Object from moto.packages.boto.ec2.ec2object import EC2Object

View File

@ -21,7 +21,7 @@
# IN THE SOFTWARE. # IN THE SOFTWARE.
class TagSet(dict): class TagSet(dict): # type: ignore
""" """
A TagSet is used to collect the tags associated with a particular A TagSet is used to collect the tags associated with a particular
EC2 resource. Not all resources can be tagged but for those that EC2 resource. Not all resources can be tagged but for those that
@ -29,7 +29,7 @@ class TagSet(dict):
:class:`boto.ec2.ec2object.TaggedEC2Object` for more details. :class:`boto.ec2.ec2object.TaggedEC2Object` for more details.
""" """
def __init__(self, connection=None): def __init__(self, connection=None): # type: ignore
self.connection = connection self.connection = connection
self._current_key = None self._current_key = None
self._current_value = None self._current_value = None

View File

@ -5,6 +5,7 @@
# SPDX-License-Identifier: MIT-0 # SPDX-License-Identifier: MIT-0
from __future__ import print_function from __future__ import print_function
from typing import Any
import urllib3 import urllib3
import json import json
@ -15,14 +16,14 @@ http = urllib3.PoolManager()
def send( def send(
event, event: Any,
context, context: Any,
responseStatus, responseStatus: Any,
responseData, responseData: Any,
physicalResourceId=None, physicalResourceId: Any = None,
noEcho=False, noEcho: bool = False,
reason=None, reason: Any = None,
): ) -> None:
responseUrl = event["ResponseURL"] responseUrl = event["ResponseURL"]
print(responseUrl) print(responseUrl)
@ -49,7 +50,7 @@ def send(
headers = {"content-type": "", "content-length": str(len(json_responseBody))} headers = {"content-type": "", "content-length": str(len(json_responseBody))}
try: try:
response = http.request( response = http.request( # type: ignore
"PUT", responseUrl, headers=headers, body=json_responseBody "PUT", responseUrl, headers=headers, body=json_responseBody
) )
print("Status code:", response.status) print("Status code:", response.status)

View File

@ -7,7 +7,7 @@ class PersonalizeException(JsonRESTError):
class ResourceNotFoundException(PersonalizeException): class ResourceNotFoundException(PersonalizeException):
def __init__(self, arn): def __init__(self, arn: str):
super().__init__( super().__init__(
"ResourceNotFoundException", f"Resource Arn {arn} does not exist." "ResourceNotFoundException", f"Resource Arn {arn} does not exist."
) )

View File

@ -1,4 +1,4 @@
"""PersonalizeBackend class with methods for supported APIs.""" from typing import Any, Dict, Iterable
from .exceptions import ResourceNotFoundException from .exceptions import ResourceNotFoundException
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
@ -6,15 +6,22 @@ from moto.core.utils import unix_time
class Schema(BaseModel): class Schema(BaseModel):
def __init__(self, account_id, region, name, schema, domain): def __init__(
self,
account_id: str,
region: str,
name: str,
schema: Dict[str, Any],
domain: str,
):
self.name = name self.name = name
self.schema = schema self.schema = schema
self.domain = domain self.domain = domain
self.arn = f"arn:aws:personalize:{region}:{account_id}:schema/{name}" self.arn = f"arn:aws:personalize:{region}:{account_id}:schema/{name}"
self.created = unix_time() self.created = unix_time()
def to_dict(self, full=True): def to_dict(self, full: bool = True) -> Dict[str, Any]:
d = { d: Dict[str, Any] = {
"name": self.name, "name": self.name,
"schemaArn": self.arn, "schemaArn": self.arn,
"domain": self.domain, "domain": self.domain,
@ -29,32 +36,32 @@ class Schema(BaseModel):
class PersonalizeBackend(BaseBackend): class PersonalizeBackend(BaseBackend):
"""Implementation of Personalize APIs.""" """Implementation of Personalize APIs."""
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.schemas: [str, Schema] = dict() self.schemas: Dict[str, Schema] = dict()
def create_schema(self, name, schema, domain): def create_schema(self, name: str, schema_dict: Dict[str, Any], domain: str) -> str:
schema = Schema( schema = Schema(
region=self.region_name, region=self.region_name,
account_id=self.account_id, account_id=self.account_id,
name=name, name=name,
schema=schema, schema=schema_dict,
domain=domain, domain=domain,
) )
self.schemas[schema.arn] = schema self.schemas[schema.arn] = schema
return schema.arn return schema.arn
def delete_schema(self, schema_arn): def delete_schema(self, schema_arn: str) -> None:
if schema_arn not in self.schemas: if schema_arn not in self.schemas:
raise ResourceNotFoundException(schema_arn) raise ResourceNotFoundException(schema_arn)
self.schemas.pop(schema_arn, None) self.schemas.pop(schema_arn, None)
def describe_schema(self, schema_arn): def describe_schema(self, schema_arn: str) -> Schema:
if schema_arn not in self.schemas: if schema_arn not in self.schemas:
raise ResourceNotFoundException(schema_arn) raise ResourceNotFoundException(schema_arn)
return self.schemas[schema_arn] return self.schemas[schema_arn]
def list_schemas(self) -> [Schema]: def list_schemas(self) -> Iterable[Schema]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """

View File

@ -2,47 +2,45 @@
import json import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import personalize_backends from .models import personalize_backends, PersonalizeBackend
class PersonalizeResponse(BaseResponse): class PersonalizeResponse(BaseResponse):
"""Handler for Personalize requests and responses.""" """Handler for Personalize requests and responses."""
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="personalize") super().__init__(service_name="personalize")
@property @property
def personalize_backend(self): def personalize_backend(self) -> PersonalizeBackend:
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return personalize_backends[self.current_account][self.region] return personalize_backends[self.current_account][self.region]
# add methods from here def create_schema(self) -> str:
def create_schema(self):
params = json.loads(self.body) params = json.loads(self.body)
name = params.get("name") name = params.get("name")
schema = params.get("schema") schema = params.get("schema")
domain = params.get("domain") domain = params.get("domain")
schema_arn = self.personalize_backend.create_schema( schema_arn = self.personalize_backend.create_schema(
name=name, name=name,
schema=schema, schema_dict=schema,
domain=domain, domain=domain,
) )
return json.dumps(dict(schemaArn=schema_arn)) return json.dumps(dict(schemaArn=schema_arn))
def delete_schema(self): def delete_schema(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
schema_arn = params.get("schemaArn") schema_arn = params.get("schemaArn")
self.personalize_backend.delete_schema(schema_arn=schema_arn) self.personalize_backend.delete_schema(schema_arn=schema_arn)
return "{}" return "{}"
def describe_schema(self): def describe_schema(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
schema_arn = params.get("schemaArn") schema_arn = params.get("schemaArn")
schema = self.personalize_backend.describe_schema(schema_arn=schema_arn) schema = self.personalize_backend.describe_schema(schema_arn=schema_arn)
return json.dumps(dict(schema=schema.to_dict())) return json.dumps(dict(schema=schema.to_dict()))
def list_schemas(self): def list_schemas(self) -> str:
schemas = self.personalize_backend.list_schemas() schemas = self.personalize_backend.list_schemas()
resp = {"schemas": [s.to_dict(full=False) for s in schemas]} resp = {"schemas": [s.to_dict(full=False) for s in schemas]}
return json.dumps(resp) return json.dumps(resp)

View File

@ -9,12 +9,12 @@ class PinpointExceptions(JsonRESTError):
class ApplicationNotFound(PinpointExceptions): class ApplicationNotFound(PinpointExceptions):
code = 404 code = 404
def __init__(self): def __init__(self) -> None:
super().__init__("NotFoundException", "Application not found") super().__init__("NotFoundException", "Application not found")
class EventStreamNotFound(PinpointExceptions): class EventStreamNotFound(PinpointExceptions):
code = 404 code = 404
def __init__(self): def __init__(self) -> None:
super().__init__("NotFoundException", "Resource not found") super().__init__("NotFoundException", "Resource not found")

View File

@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Iterable, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
@ -8,7 +9,7 @@ from .exceptions import ApplicationNotFound, EventStreamNotFound
class App(BaseModel): class App(BaseModel):
def __init__(self, account_id, name): def __init__(self, account_id: str, name: str):
self.application_id = str(mock_random.uuid4()).replace("-", "") self.application_id = str(mock_random.uuid4()).replace("-", "")
self.arn = ( self.arn = (
f"arn:aws:mobiletargeting:us-east-1:{account_id}:apps/{self.application_id}" f"arn:aws:mobiletargeting:us-east-1:{account_id}:apps/{self.application_id}"
@ -16,30 +17,30 @@ class App(BaseModel):
self.name = name self.name = name
self.created = unix_time() self.created = unix_time()
self.settings = AppSettings() self.settings = AppSettings()
self.event_stream = None self.event_stream: Optional[EventStream] = None
def get_settings(self): def get_settings(self) -> "AppSettings":
return self.settings return self.settings
def update_settings(self, settings): def update_settings(self, settings: Dict[str, Any]) -> "AppSettings":
self.settings.update(settings) self.settings.update(settings)
return self.settings return self.settings
def delete_event_stream(self): def delete_event_stream(self) -> "EventStream":
stream = self.event_stream stream = self.event_stream
self.event_stream = None self.event_stream = None
return stream return stream # type: ignore
def get_event_stream(self): def get_event_stream(self) -> "EventStream":
if self.event_stream is None: if self.event_stream is None:
raise EventStreamNotFound() raise EventStreamNotFound()
return self.event_stream return self.event_stream
def put_event_stream(self, stream_arn, role_arn): def put_event_stream(self, stream_arn: str, role_arn: str) -> "EventStream":
self.event_stream = EventStream(stream_arn, role_arn) self.event_stream = EventStream(stream_arn, role_arn)
return self.event_stream return self.event_stream
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"Id": self.application_id, "Id": self.application_id,
@ -49,15 +50,15 @@ class App(BaseModel):
class AppSettings(BaseModel): class AppSettings(BaseModel):
def __init__(self): def __init__(self) -> None:
self.settings = dict() self.settings: Dict[str, Any] = dict()
self.last_modified = unix_time() self.last_modified = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
def update(self, settings): def update(self, settings: Dict[str, Any]) -> None:
self.settings = settings self.settings = settings
self.last_modified = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ") self.last_modified = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"CampaignHook": self.settings.get("CampaignHook", {}), "CampaignHook": self.settings.get("CampaignHook", {}),
"CloudWatchMetricsEnabled": self.settings.get( "CloudWatchMetricsEnabled": self.settings.get(
@ -70,12 +71,12 @@ class AppSettings(BaseModel):
class EventStream(BaseModel): class EventStream(BaseModel):
def __init__(self, stream_arn, role_arn): def __init__(self, stream_arn: str, role_arn: str):
self.stream_arn = stream_arn self.stream_arn = stream_arn
self.role_arn = role_arn self.role_arn = role_arn
self.last_modified = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ") self.last_modified = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"DestinationStreamArn": self.stream_arn, "DestinationStreamArn": self.stream_arn,
"RoleArn": self.role_arn, "RoleArn": self.role_arn,
@ -86,62 +87,65 @@ class EventStream(BaseModel):
class PinpointBackend(BaseBackend): class PinpointBackend(BaseBackend):
"""Implementation of Pinpoint APIs.""" """Implementation of Pinpoint APIs."""
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.apps = {} self.apps: Dict[str, App] = {}
self.tagger = TaggingService() self.tagger = TaggingService()
def create_app(self, name, tags): def create_app(self, name: str, tags: Dict[str, str]) -> App:
app = App(self.account_id, name) app = App(self.account_id, name)
self.apps[app.application_id] = app self.apps[app.application_id] = app
tags = self.tagger.convert_dict_to_tags_input(tags) tag_list = self.tagger.convert_dict_to_tags_input(tags)
self.tagger.tag_resource(app.arn, tags) self.tagger.tag_resource(app.arn, tag_list)
return app return app
def delete_app(self, application_id): def delete_app(self, application_id: str) -> App:
self.get_app(application_id) self.get_app(application_id)
return self.apps.pop(application_id) return self.apps.pop(application_id)
def get_app(self, application_id): def get_app(self, application_id: str) -> App:
if application_id not in self.apps: if application_id not in self.apps:
raise ApplicationNotFound() raise ApplicationNotFound()
return self.apps[application_id] return self.apps[application_id]
def get_apps(self): def get_apps(self) -> Iterable[App]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
return self.apps.values() return self.apps.values()
def update_application_settings(self, application_id, settings): def update_application_settings(
self, application_id: str, settings: Dict[str, Any]
) -> AppSettings:
app = self.get_app(application_id) app = self.get_app(application_id)
return app.update_settings(settings) return app.update_settings(settings)
def get_application_settings(self, application_id): def get_application_settings(self, application_id: str) -> AppSettings:
app = self.get_app(application_id) app = self.get_app(application_id)
return app.get_settings() return app.get_settings()
def list_tags_for_resource(self, resource_arn): def list_tags_for_resource(self, resource_arn: str) -> Dict[str, Dict[str, str]]:
tags = self.tagger.get_tag_dict_for_resource(resource_arn) tags = self.tagger.get_tag_dict_for_resource(resource_arn)
return {"tags": tags} return {"tags": tags}
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
tags = TaggingService.convert_dict_to_tags_input(tags) tag_list = TaggingService.convert_dict_to_tags_input(tags)
self.tagger.tag_resource(resource_arn, tags) self.tagger.tag_resource(resource_arn, tag_list)
def untag_resource(self, resource_arn, tag_keys): def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys) self.tagger.untag_resource_using_names(resource_arn, tag_keys)
return
def put_event_stream(self, application_id, stream_arn, role_arn): def put_event_stream(
self, application_id: str, stream_arn: str, role_arn: str
) -> EventStream:
app = self.get_app(application_id) app = self.get_app(application_id)
return app.put_event_stream(stream_arn, role_arn) return app.put_event_stream(stream_arn, role_arn)
def get_event_stream(self, application_id): def get_event_stream(self, application_id: str) -> EventStream:
app = self.get_app(application_id) app = self.get_app(application_id)
return app.get_event_stream() return app.get_event_stream()
def delete_event_stream(self, application_id): def delete_event_stream(self, application_id: str) -> EventStream:
app = self.get_app(application_id) app = self.get_app(application_id)
return app.delete_event_stream() return app.delete_event_stream()

View File

@ -1,44 +1,46 @@
"""Handles incoming pinpoint requests, invokes methods, returns responses.""" """Handles incoming pinpoint requests, invokes methods, returns responses."""
import json import json
from typing import Any
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from urllib.parse import unquote from urllib.parse import unquote
from .models import pinpoint_backends from .models import pinpoint_backends, PinpointBackend
class PinpointResponse(BaseResponse): class PinpointResponse(BaseResponse):
"""Handler for Pinpoint requests and responses.""" """Handler for Pinpoint requests and responses."""
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="pinpoint") super().__init__(service_name="pinpoint")
@property @property
def pinpoint_backend(self): def pinpoint_backend(self) -> PinpointBackend:
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return pinpoint_backends[self.current_account][self.region] return pinpoint_backends[self.current_account][self.region]
def app(self, request, full_url, headers): def app(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_app() return self.delete_app()
if request.method == "GET": if request.method == "GET":
return self.get_app() return self.get_app()
def apps(self, request, full_url, headers): def apps(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.get_apps() return self.get_apps()
if request.method == "POST": if request.method == "POST":
return self.create_app() return self.create_app()
def app_settings(self, request, full_url, headers): def app_settings(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.get_application_settings() return self.get_application_settings()
if request.method == "PUT": if request.method == "PUT":
return self.update_application_settings() return self.update_application_settings()
def eventstream(self, request, full_url, headers): def eventstream(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_event_stream() return self.delete_event_stream()
@ -47,7 +49,7 @@ class PinpointResponse(BaseResponse):
if request.method == "POST": if request.method == "POST":
return self.put_event_stream() return self.put_event_stream()
def tags(self, request, full_url, headers): def tags(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "DELETE": if request.method == "DELETE":
return self.untag_resource() return self.untag_resource()
@ -56,67 +58,67 @@ class PinpointResponse(BaseResponse):
if request.method == "POST": if request.method == "POST":
return self.tag_resource() return self.tag_resource()
def create_app(self): def create_app(self) -> TYPE_RESPONSE:
params = json.loads(self.body) params = json.loads(self.body)
name = params.get("Name") name = params.get("Name")
tags = params.get("tags", {}) tags = params.get("tags", {})
app = self.pinpoint_backend.create_app(name=name, tags=tags) app = self.pinpoint_backend.create_app(name=name, tags=tags)
return 201, {}, json.dumps(app.to_json()) return 201, {}, json.dumps(app.to_json())
def delete_app(self): def delete_app(self) -> TYPE_RESPONSE:
application_id = self.path.split("/")[-1] application_id = self.path.split("/")[-1]
app = self.pinpoint_backend.delete_app(application_id=application_id) app = self.pinpoint_backend.delete_app(application_id=application_id)
return 200, {}, json.dumps(app.to_json()) return 200, {}, json.dumps(app.to_json())
def get_app(self): def get_app(self) -> TYPE_RESPONSE:
application_id = self.path.split("/")[-1] application_id = self.path.split("/")[-1]
app = self.pinpoint_backend.get_app(application_id=application_id) app = self.pinpoint_backend.get_app(application_id=application_id)
return 200, {}, json.dumps(app.to_json()) return 200, {}, json.dumps(app.to_json())
def get_apps(self): def get_apps(self) -> TYPE_RESPONSE:
apps = self.pinpoint_backend.get_apps() apps = self.pinpoint_backend.get_apps()
resp = {"Item": [a.to_json() for a in apps]} resp = {"Item": [a.to_json() for a in apps]}
return 200, {}, json.dumps(resp) return 200, {}, json.dumps(resp)
def update_application_settings(self): def update_application_settings(self) -> TYPE_RESPONSE:
application_id = self.path.split("/")[-2] application_id = self.path.split("/")[-2]
settings = json.loads(self.body) settings = json.loads(self.body)
app_settings = self.pinpoint_backend.update_application_settings( app_settings = self.pinpoint_backend.update_application_settings(
application_id=application_id, settings=settings application_id=application_id, settings=settings
) )
app_settings = app_settings.to_json() response = app_settings.to_json()
app_settings["ApplicationId"] = application_id response["ApplicationId"] = application_id
return 200, {}, json.dumps(app_settings) return 200, {}, json.dumps(response)
def get_application_settings(self): def get_application_settings(self) -> TYPE_RESPONSE:
application_id = self.path.split("/")[-2] application_id = self.path.split("/")[-2]
app_settings = self.pinpoint_backend.get_application_settings( app_settings = self.pinpoint_backend.get_application_settings(
application_id=application_id application_id=application_id
) )
app_settings = app_settings.to_json() response = app_settings.to_json()
app_settings["ApplicationId"] = application_id response["ApplicationId"] = application_id
return 200, {}, json.dumps(app_settings) return 200, {}, json.dumps(response)
def list_tags_for_resource(self): def list_tags_for_resource(self) -> TYPE_RESPONSE:
resource_arn = unquote(self.path).split("/tags/")[-1] resource_arn = unquote(self.path).split("/tags/")[-1]
tags = self.pinpoint_backend.list_tags_for_resource(resource_arn=resource_arn) tags = self.pinpoint_backend.list_tags_for_resource(resource_arn=resource_arn)
return 200, {}, json.dumps(tags) return 200, {}, json.dumps(tags)
def tag_resource(self): def tag_resource(self) -> TYPE_RESPONSE:
resource_arn = unquote(self.path).split("/tags/")[-1] resource_arn = unquote(self.path).split("/tags/")[-1]
tags = json.loads(self.body).get("tags", {}) tags = json.loads(self.body).get("tags", {})
self.pinpoint_backend.tag_resource(resource_arn=resource_arn, tags=tags) self.pinpoint_backend.tag_resource(resource_arn=resource_arn, tags=tags)
return 200, {}, "{}" return 200, {}, "{}"
def untag_resource(self): def untag_resource(self) -> TYPE_RESPONSE:
resource_arn = unquote(self.path).split("/tags/")[-1] resource_arn = unquote(self.path).split("/tags/")[-1]
tag_keys = self.querystring.get("tagKeys") tag_keys = self.querystring.get("tagKeys")
self.pinpoint_backend.untag_resource( self.pinpoint_backend.untag_resource(
resource_arn=resource_arn, tag_keys=tag_keys resource_arn=resource_arn, tag_keys=tag_keys # type: ignore[arg-type]
) )
return 200, {}, "{}" return 200, {}, "{}"
def put_event_stream(self): def put_event_stream(self) -> TYPE_RESPONSE:
application_id = self.path.split("/")[-2] application_id = self.path.split("/")[-2]
params = json.loads(self.body) params = json.loads(self.body)
stream_arn = params.get("DestinationStreamArn") stream_arn = params.get("DestinationStreamArn")
@ -128,7 +130,7 @@ class PinpointResponse(BaseResponse):
resp["ApplicationId"] = application_id resp["ApplicationId"] = application_id
return 200, {}, json.dumps(resp) return 200, {}, json.dumps(resp)
def get_event_stream(self): def get_event_stream(self) -> TYPE_RESPONSE:
application_id = self.path.split("/")[-2] application_id = self.path.split("/")[-2]
event_stream = self.pinpoint_backend.get_event_stream( event_stream = self.pinpoint_backend.get_event_stream(
application_id=application_id application_id=application_id
@ -137,7 +139,7 @@ class PinpointResponse(BaseResponse):
resp["ApplicationId"] = application_id resp["ApplicationId"] = application_id
return 200, {}, json.dumps(resp) return 200, {}, json.dumps(resp)
def delete_event_stream(self): def delete_event_stream(self) -> TYPE_RESPONSE:
application_id = self.path.split("/")[-2] application_id = self.path.split("/")[-2]
event_stream = self.pinpoint_backend.delete_event_stream( event_stream = self.pinpoint_backend.delete_event_stream(
application_id=application_id application_id=application_id

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
import datetime import datetime
@ -8,7 +9,7 @@ from .utils import make_arn_for_lexicon
class Lexicon(BaseModel): class Lexicon(BaseModel):
def __init__(self, name, content, account_id, region_name): def __init__(self, name: str, content: str, account_id: str, region_name: str):
self.name = name self.name = name
self.content = content self.content = content
self.size = 0 self.size = 0
@ -20,7 +21,7 @@ class Lexicon(BaseModel):
self.update() self.update()
def update(self, content=None): def update(self, content: Optional[str] = None) -> None:
if content is not None: if content is not None:
self.content = content self.content = content
@ -28,7 +29,7 @@ class Lexicon(BaseModel):
try: try:
root = ET.fromstring(self.content) root = ET.fromstring(self.content)
self.size = len(self.content) self.size = len(self.content)
self.last_modified = int( self.last_modified = int( # type: ignore
( (
datetime.datetime.now() - datetime.datetime(1970, 1, 1) datetime.datetime.now() - datetime.datetime(1970, 1, 1)
).total_seconds() ).total_seconds()
@ -37,14 +38,14 @@ class Lexicon(BaseModel):
for key, value in root.attrib.items(): for key, value in root.attrib.items():
if key.endswith("alphabet"): if key.endswith("alphabet"):
self.alphabet = value self.alphabet = value # type: ignore
elif key.endswith("lang"): elif key.endswith("lang"):
self.language_code = value self.language_code = value # type: ignore
except Exception as err: except Exception as err:
raise ValueError(f"Failure parsing XML: {err}") raise ValueError(f"Failure parsing XML: {err}")
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"Attributes": { "Attributes": {
"Alphabet": self.alphabet, "Alphabet": self.alphabet,
@ -56,16 +57,16 @@ class Lexicon(BaseModel):
} }
} }
def __repr__(self): def __repr__(self) -> str:
return f"<Lexicon {self.name}>" return f"<Lexicon {self.name}>"
class PollyBackend(BaseBackend): class PollyBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self._lexicons = {} self._lexicons: Dict[str, Lexicon] = {}
def describe_voices(self, language_code): def describe_voices(self, language_code: str) -> List[Dict[str, Any]]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
@ -74,15 +75,15 @@ class PollyBackend(BaseBackend):
return [item for item in VOICE_DATA if item["LanguageCode"] == language_code] return [item for item in VOICE_DATA if item["LanguageCode"] == language_code]
def delete_lexicon(self, name): def delete_lexicon(self, name: str) -> None:
# implement here # implement here
del self._lexicons[name] del self._lexicons[name]
def get_lexicon(self, name): def get_lexicon(self, name: str) -> Lexicon:
# Raises KeyError # Raises KeyError
return self._lexicons[name] return self._lexicons[name]
def list_lexicons(self): def list_lexicons(self) -> List[Dict[str, Any]]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
@ -97,12 +98,12 @@ class PollyBackend(BaseBackend):
return result return result
def put_lexicon(self, name, content): def put_lexicon(self, name: str, content: str) -> None:
# If lexicon content is bad, it will raise ValueError # If lexicon content is bad, it will raise ValueError
if name in self._lexicons: if name in self._lexicons:
# Regenerated all the stats from the XML # Regenerated all the stats from the XML
# but keeps the ARN # but keeps the ARN
self._lexicons.update(content) self._lexicons[name].update(content)
else: else:
lexicon = Lexicon( lexicon = Lexicon(
name, content, self.account_id, region_name=self.region_name name, content, self.account_id, region_name=self.region_name

View File

@ -1,33 +1,33 @@
import json import json
import re import re
from typing import Any, Dict, Tuple, Union
from urllib.parse import urlsplit from urllib.parse import urlsplit
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import polly_backends from .models import polly_backends, PollyBackend
from .resources import LANGUAGE_CODES, VOICE_IDS from .resources import LANGUAGE_CODES, VOICE_IDS
LEXICON_NAME_REGEX = re.compile(r"^[0-9A-Za-z]{1,20}$") LEXICON_NAME_REGEX = re.compile(r"^[0-9A-Za-z]{1,20}$")
class PollyResponse(BaseResponse): class PollyResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="polly") super().__init__(service_name="polly")
@property @property
def polly_backend(self): def polly_backend(self) -> PollyBackend:
return polly_backends[self.current_account][self.region] return polly_backends[self.current_account][self.region]
@property @property
def json(self): def json(self) -> Dict[str, Any]: # type: ignore[misc]
if not hasattr(self, "_json"): if not hasattr(self, "_json"):
self._json = json.loads(self.body) self._json = json.loads(self.body)
return self._json return self._json
def _error(self, code, message): def _error(self, code: str, message: str) -> Tuple[str, Dict[str, int]]:
return json.dumps({"__type": code, "message": message}), dict(status=400) return json.dumps({"__type": code, "message": message}), dict(status=400)
def _get_action(self): def _get_action(self) -> str:
# Amazon is now naming things /v1/api_name # Amazon is now naming things /v1/api_name
url_parts = urlsplit(self.uri).path.lstrip("/").split("/") url_parts = urlsplit(self.uri).path.lstrip("/").split("/")
# [0] = 'v1' # [0] = 'v1'
@ -35,11 +35,11 @@ class PollyResponse(BaseResponse):
return url_parts[1] return url_parts[1]
# DescribeVoices # DescribeVoices
def voices(self): def voices(self) -> Union[str, Tuple[str, Dict[str, int]]]:
language_code = self._get_param("LanguageCode") language_code = self._get_param("LanguageCode")
if language_code is not None and language_code not in LANGUAGE_CODES: if language_code is not None and language_code not in LANGUAGE_CODES:
all_codes = ", ".join(LANGUAGE_CODES) all_codes = ", ".join(LANGUAGE_CODES) # type: ignore
msg = ( msg = (
f"1 validation error detected: Value '{language_code}' at 'languageCode' failed to satisfy constraint: " f"1 validation error detected: Value '{language_code}' at 'languageCode' failed to satisfy constraint: "
f"Member must satisfy enum value set: [{all_codes}]" f"Member must satisfy enum value set: [{all_codes}]"
@ -50,7 +50,7 @@ class PollyResponse(BaseResponse):
return json.dumps({"Voices": voices}) return json.dumps({"Voices": voices})
def lexicons(self): def lexicons(self) -> Union[str, Tuple[str, Dict[str, int]]]:
# Dish out requests based on methods # Dish out requests based on methods
# anything after the /v1/lexicons/ # anything after the /v1/lexicons/
@ -69,7 +69,9 @@ class PollyResponse(BaseResponse):
return self._error("InvalidAction", "Bad route") return self._error("InvalidAction", "Bad route")
# PutLexicon # PutLexicon
def _put_lexicons(self, lexicon_name): def _put_lexicons(
self, lexicon_name: str
) -> Union[str, Tuple[str, Dict[str, int]]]:
if LEXICON_NAME_REGEX.match(lexicon_name) is None: if LEXICON_NAME_REGEX.match(lexicon_name) is None:
return self._error( return self._error(
"InvalidParameterValue", "Lexicon name must match [0-9A-Za-z]{1,20}" "InvalidParameterValue", "Lexicon name must match [0-9A-Za-z]{1,20}"
@ -83,13 +85,13 @@ class PollyResponse(BaseResponse):
return "" return ""
# ListLexicons # ListLexicons
def _get_lexicons_list(self): def _get_lexicons_list(self) -> str:
result = {"Lexicons": self.polly_backend.list_lexicons()} result = {"Lexicons": self.polly_backend.list_lexicons()}
return json.dumps(result) return json.dumps(result)
# GetLexicon # GetLexicon
def _get_lexicon(self, lexicon_name): def _get_lexicon(self, lexicon_name: str) -> Union[str, Tuple[str, Dict[str, int]]]:
try: try:
lexicon = self.polly_backend.get_lexicon(lexicon_name) lexicon = self.polly_backend.get_lexicon(lexicon_name)
except KeyError: except KeyError:
@ -103,7 +105,9 @@ class PollyResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# DeleteLexicon # DeleteLexicon
def _delete_lexicon(self, lexicon_name): def _delete_lexicon(
self, lexicon_name: str
) -> Union[str, Tuple[str, Dict[str, int]]]:
try: try:
self.polly_backend.delete_lexicon(lexicon_name) self.polly_backend.delete_lexicon(lexicon_name)
except KeyError: except KeyError:
@ -112,7 +116,7 @@ class PollyResponse(BaseResponse):
return "" return ""
# SynthesizeSpeech # SynthesizeSpeech
def speech(self): def speech(self) -> Tuple[str, Dict[str, Any]]:
# Sanity check params # Sanity check params
args = { args = {
"lexicon_names": None, "lexicon_names": None,
@ -169,12 +173,12 @@ class PollyResponse(BaseResponse):
if "VoiceId" not in self.json: if "VoiceId" not in self.json:
return self._error("MissingParameter", "Missing parameter VoiceId") return self._error("MissingParameter", "Missing parameter VoiceId")
if self.json["VoiceId"] not in VOICE_IDS: if self.json["VoiceId"] not in VOICE_IDS:
all_voices = ", ".join(VOICE_IDS) all_voices = ", ".join(VOICE_IDS) # type: ignore
return self._error("InvalidParameterValue", f"Not one of {all_voices}") return self._error("InvalidParameterValue", f"Not one of {all_voices}")
args["voice_id"] = self.json["VoiceId"] args["voice_id"] = self.json["VoiceId"]
# More validation # More validation
if len(args["text"]) > 3000: if len(args["text"]) > 3000: # type: ignore
return self._error("TextLengthExceededException", "Text too long") return self._error("TextLengthExceededException", "Text too long")
if args["speech_marks"] is not None and args["output_format"] != "json": if args["speech_marks"] is not None and args["output_format"] != "json":

View File

@ -1,2 +1,2 @@
def make_arn_for_lexicon(account_id, name, region_name): def make_arn_for_lexicon(account_id: str, name: str, region_name: str) -> str:
return f"arn:aws:polly:{region_name}:{account_id}:lexicon/{name}" return f"arn:aws:polly:{region_name}:{account_id}:lexicon/{name}"

View File

@ -235,7 +235,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/rdsdata files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/rdsdata
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract

View File

@ -258,3 +258,10 @@ def test_synthesize_speech_bad_speech_marks2():
) )
else: else:
raise RuntimeError("Should have raised ") raise RuntimeError("Should have raised ")
@mock_polly
def test_update_lexicon():
client = boto3.client("polly", region_name=DEFAULT_REGION)
client.put_lexicon(Name="test", Content=LEXICON_XML)
client.put_lexicon(Name="test", Content=LEXICON_XML)