diff --git a/moto/textract/exceptions.py b/moto/textract/exceptions.py index 9c0e512fc..3391df15e 100644 --- a/moto/textract/exceptions.py +++ b/moto/textract/exceptions.py @@ -5,16 +5,16 @@ from moto.core.exceptions import JsonRESTError class InvalidJobIdException(JsonRESTError): code = 400 - def __init__(self): - super().__init__(__class__.__name__, "An invalid job identifier was passed.") + def __init__(self) -> None: + super().__init__(__class__.__name__, "An invalid job identifier was passed.") # type: ignore class InvalidS3ObjectException(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( - __class__.__name__, + __class__.__name__, # type: ignore "Amazon Textract is unable to access the S3 object that's specified in the request.", ) @@ -22,8 +22,8 @@ class InvalidS3ObjectException(JsonRESTError): class InvalidParameterException(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( - __class__.__name__, + __class__.__name__, # type: ignore "An input parameter violated a constraint. For example, in synchronous operations, an InvalidParameterException exception occurs when neither of the S3Object or Bytes values are supplied in the Document request parameter. Validate your parameter before calling the API operation again.", ) diff --git a/moto/textract/models.py b/moto/textract/models.py index a7526ad16..876acc2f5 100644 --- a/moto/textract/models.py +++ b/moto/textract/models.py @@ -1,6 +1,5 @@ -"""TextractBackend class with methods for supported APIs.""" - from collections import defaultdict +from typing import Any, Dict, List from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api._internal import mock_random @@ -16,10 +15,10 @@ class TextractJobStatus: class TextractJob(BaseModel): - def __init__(self, job): + def __init__(self, job: Dict[str, Any]): self.job = job - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return self.job @@ -28,13 +27,13 @@ class TextractBackend(BaseBackend): JOB_STATUS = TextractJobStatus.succeeded PAGES = {"Pages": mock_random.randint(5, 500)} - BLOCKS = [] + BLOCKS: List[Dict[str, Any]] = [] - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.async_text_detection_jobs = defaultdict() + self.async_text_detection_jobs: Dict[str, TextractJob] = defaultdict() - def get_document_text_detection(self, job_id): + def get_document_text_detection(self, job_id: str) -> TextractJob: """ Pagination has not yet been implemented """ @@ -43,7 +42,7 @@ class TextractBackend(BaseBackend): raise InvalidJobIdException() return job - def start_document_text_detection(self, document_location): + def start_document_text_detection(self, document_location: str) -> str: """ The following parameters have not yet been implemented: ClientRequestToken, JobTag, NotificationChannel, OutputConfig, KmsKeyID """ diff --git a/moto/textract/responses.py b/moto/textract/responses.py index b10949531..df72a5ded 100644 --- a/moto/textract/responses.py +++ b/moto/textract/responses.py @@ -2,27 +2,27 @@ import json from moto.core.responses import BaseResponse -from .models import textract_backends +from .models import textract_backends, TextractBackend class TextractResponse(BaseResponse): """Handler for Textract requests and responses.""" - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="textract") @property - def textract_backend(self): + def textract_backend(self) -> TextractBackend: """Return backend instance specific for this region.""" return textract_backends[self.current_account][self.region] - def get_document_text_detection(self): + def get_document_text_detection(self) -> str: params = json.loads(self.body) job_id = params.get("JobId") job = self.textract_backend.get_document_text_detection(job_id=job_id).to_dict() return json.dumps(job) - def start_document_text_detection(self): + def start_document_text_detection(self) -> str: params = json.loads(self.body) document_location = params.get("DocumentLocation") job_id = self.textract_backend.start_document_text_detection( diff --git a/moto/timestreamwrite/exceptions.py b/moto/timestreamwrite/exceptions.py index f1598bd65..18efc08d5 100644 --- a/moto/timestreamwrite/exceptions.py +++ b/moto/timestreamwrite/exceptions.py @@ -5,5 +5,5 @@ from moto.core.exceptions import JsonRESTError class ResourceNotFound(JsonRESTError): error_type = "com.amazonaws.timestream.v20181101#ResourceNotFoundException" - def __init__(self, msg): + def __init__(self, msg: str): super().__init__(ResourceNotFound.error_type, msg) diff --git a/moto/timestreamwrite/models.py b/moto/timestreamwrite/models.py index c14b8b46e..4aae3226c 100644 --- a/moto/timestreamwrite/models.py +++ b/moto/timestreamwrite/models.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, List, Iterable from moto.core import BaseBackend, BackendDict, BaseModel from moto.utilities.tagging_service import TaggingService from .exceptions import ResourceNotFound @@ -6,12 +7,12 @@ from .exceptions import ResourceNotFound class TimestreamTable(BaseModel): def __init__( self, - account_id, - region_name, - table_name, - db_name, - retention_properties, - magnetic_store_write_properties, + account_id: str, + region_name: str, + table_name: str, + db_name: str, + retention_properties: Dict[str, int], + magnetic_store_write_properties: Dict[str, Any], ): self.region_name = region_name self.name = table_name @@ -21,18 +22,22 @@ class TimestreamTable(BaseModel): "MagneticStoreRetentionPeriodInDays": 123, } self.magnetic_store_write_properties = magnetic_store_write_properties or {} - self.records = [] + self.records: List[Dict[str, Any]] = [] self.arn = f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.db_name}/table/{self.name}" - def update(self, retention_properties, magnetic_store_write_properties): + def update( + self, + retention_properties: Dict[str, int], + magnetic_store_write_properties: Dict[str, Any], + ) -> None: self.retention_properties = retention_properties if magnetic_store_write_properties is not None: self.magnetic_store_write_properties = magnetic_store_write_properties - def write_records(self, records): + def write_records(self, records: List[Dict[str, Any]]) -> None: self.records.extend(records) - def description(self): + def description(self) -> Dict[str, Any]: return { "Arn": self.arn, "TableName": self.name, @@ -44,7 +49,9 @@ class TimestreamTable(BaseModel): class TimestreamDatabase(BaseModel): - def __init__(self, account_id, region_name, database_name, kms_key_id): + def __init__( + self, account_id: str, region_name: str, database_name: str, kms_key_id: str + ): self.account_id = account_id self.region_name = region_name self.name = database_name @@ -54,14 +61,17 @@ class TimestreamDatabase(BaseModel): self.arn = ( f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.name}" ) - self.tables = dict() + self.tables: Dict[str, TimestreamTable] = dict() - def update(self, kms_key_id): + def update(self, kms_key_id: str) -> None: self.kms_key_id = kms_key_id def create_table( - self, table_name, retention_properties, magnetic_store_write_properties - ): + self, + table_name: str, + retention_properties: Dict[str, int], + magnetic_store_write_properties: Dict[str, Any], + ) -> TimestreamTable: table = TimestreamTable( account_id=self.account_id, region_name=self.region_name, @@ -74,8 +84,11 @@ class TimestreamDatabase(BaseModel): return table def update_table( - self, table_name, retention_properties, magnetic_store_write_properties - ): + self, + table_name: str, + retention_properties: Dict[str, int], + magnetic_store_write_properties: Dict[str, Any], + ) -> TimestreamTable: table = self.tables[table_name] table.update( retention_properties=retention_properties, @@ -83,18 +96,18 @@ class TimestreamDatabase(BaseModel): ) return table - def delete_table(self, table_name): + def delete_table(self, table_name: str) -> None: self.tables.pop(table_name, None) - def describe_table(self, table_name): + def describe_table(self, table_name: str) -> TimestreamTable: if table_name not in self.tables: raise ResourceNotFound(f"The table {table_name} does not exist.") return self.tables[table_name] - def list_tables(self): + def list_tables(self) -> Iterable[TimestreamTable]: return self.tables.values() - def description(self): + def description(self) -> Dict[str, Any]: return { "Arn": self.arn, "DatabaseName": self.name, @@ -117,12 +130,14 @@ class TimestreamWriteBackend(BaseBackend): """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.databases = dict() + self.databases: Dict[str, TimestreamDatabase] = dict() self.tagging_service = TaggingService() - def create_database(self, database_name, kms_key_id, tags): + def create_database( + self, database_name: str, kms_key_id: str, tags: List[Dict[str, str]] + ) -> TimestreamDatabase: database = TimestreamDatabase( self.account_id, self.region_name, database_name, kms_key_id ) @@ -130,30 +145,32 @@ class TimestreamWriteBackend(BaseBackend): self.tagging_service.tag_resource(database.arn, tags) return database - def delete_database(self, database_name): + def delete_database(self, database_name: str) -> None: del self.databases[database_name] - def describe_database(self, database_name): + def describe_database(self, database_name: str) -> TimestreamDatabase: if database_name not in self.databases: raise ResourceNotFound(f"The database {database_name} does not exist.") return self.databases[database_name] - def list_databases(self): + def list_databases(self) -> Iterable[TimestreamDatabase]: return self.databases.values() - def update_database(self, database_name, kms_key_id): + def update_database( + self, database_name: str, kms_key_id: str + ) -> TimestreamDatabase: database = self.databases[database_name] database.update(kms_key_id=kms_key_id) return database def create_table( self, - database_name, - table_name, - retention_properties, - tags, - magnetic_store_write_properties, - ): + database_name: str, + table_name: str, + retention_properties: Dict[str, int], + tags: List[Dict[str, str]], + magnetic_store_write_properties: Dict[str, Any], + ) -> TimestreamTable: database = self.describe_database(database_name) table = database.create_table( table_name, retention_properties, magnetic_store_write_properties @@ -161,39 +178,38 @@ class TimestreamWriteBackend(BaseBackend): self.tagging_service.tag_resource(table.arn, tags) return table - def delete_table(self, database_name, table_name): + def delete_table(self, database_name: str, table_name: str) -> None: database = self.describe_database(database_name) database.delete_table(table_name) - def describe_table(self, database_name, table_name): + def describe_table(self, database_name: str, table_name: str) -> TimestreamTable: database = self.describe_database(database_name) - table = database.describe_table(table_name) - return table + return database.describe_table(table_name) - def list_tables(self, database_name): + def list_tables(self, database_name: str) -> Iterable[TimestreamTable]: database = self.describe_database(database_name) - tables = database.list_tables() - return tables + return database.list_tables() def update_table( self, - database_name, - table_name, - retention_properties, - magnetic_store_write_properties, - ): + database_name: str, + table_name: str, + retention_properties: Dict[str, int], + magnetic_store_write_properties: Dict[str, Any], + ) -> TimestreamTable: database = self.describe_database(database_name) - table = database.update_table( + return database.update_table( table_name, retention_properties, magnetic_store_write_properties ) - return table - def write_records(self, database_name, table_name, records): + def write_records( + self, database_name: str, table_name: str, records: List[Dict[str, Any]] + ) -> None: database = self.describe_database(database_name) table = database.describe_table(table_name) table.write_records(records) - def describe_endpoints(self): + def describe_endpoints(self) -> Dict[str, List[Dict[str, Any]]]: # https://docs.aws.amazon.com/timestream/latest/developerguide/Using-API.endpoint-discovery.how-it-works.html # Usually, the address look like this: # ingest-cell1.timestream.us-east-1.amazonaws.com @@ -208,13 +224,15 @@ class TimestreamWriteBackend(BaseBackend): ] } - def list_tags_for_resource(self, resource_arn): + def list_tags_for_resource( + self, resource_arn: str + ) -> Dict[str, List[Dict[str, str]]]: return self.tagging_service.list_tags_for_resource(resource_arn) - def tag_resource(self, resource_arn, tags): + def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None: self.tagging_service.tag_resource(resource_arn, tags) - def untag_resource(self, resource_arn, tag_keys): + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: self.tagging_service.untag_resource_using_names(resource_arn, tag_keys) diff --git a/moto/timestreamwrite/responses.py b/moto/timestreamwrite/responses.py index 2580b3a3b..cfb8e4829 100644 --- a/moto/timestreamwrite/responses.py +++ b/moto/timestreamwrite/responses.py @@ -1,19 +1,19 @@ import json from moto.core.responses import BaseResponse -from .models import timestreamwrite_backends +from .models import timestreamwrite_backends, TimestreamWriteBackend class TimestreamWriteResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="timestream-write") @property - def timestreamwrite_backend(self): + def timestreamwrite_backend(self) -> TimestreamWriteBackend: """Return backend instance specific for this region.""" return timestreamwrite_backends[self.current_account][self.region] - def create_database(self): + def create_database(self) -> str: database_name = self._get_param("DatabaseName") kms_key_id = self._get_param("KmsKeyId") tags = self._get_param("Tags") @@ -22,19 +22,19 @@ class TimestreamWriteResponse(BaseResponse): ) return json.dumps(dict(Database=database.description())) - def delete_database(self): + def delete_database(self) -> str: database_name = self._get_param("DatabaseName") self.timestreamwrite_backend.delete_database(database_name=database_name) return "{}" - def describe_database(self): + def describe_database(self) -> str: database_name = self._get_param("DatabaseName") database = self.timestreamwrite_backend.describe_database( database_name=database_name ) return json.dumps(dict(Database=database.description())) - def update_database(self): + def update_database(self) -> str: database_name = self._get_param("DatabaseName") kms_key_id = self._get_param("KmsKeyId") database = self.timestreamwrite_backend.update_database( @@ -42,11 +42,11 @@ class TimestreamWriteResponse(BaseResponse): ) return json.dumps(dict(Database=database.description())) - def list_databases(self): + def list_databases(self) -> str: all_dbs = self.timestreamwrite_backend.list_databases() return json.dumps(dict(Databases=[db.description() for db in all_dbs])) - def create_table(self): + def create_table(self) -> str: database_name = self._get_param("DatabaseName") table_name = self._get_param("TableName") retention_properties = self._get_param("RetentionProperties") @@ -63,24 +63,24 @@ class TimestreamWriteResponse(BaseResponse): ) return json.dumps(dict(Table=table.description())) - def delete_table(self): + def delete_table(self) -> str: database_name = self._get_param("DatabaseName") table_name = self._get_param("TableName") self.timestreamwrite_backend.delete_table(database_name, table_name) return "{}" - def describe_table(self): + def describe_table(self) -> str: database_name = self._get_param("DatabaseName") table_name = self._get_param("TableName") table = self.timestreamwrite_backend.describe_table(database_name, table_name) return json.dumps(dict(Table=table.description())) - def list_tables(self): + def list_tables(self) -> str: database_name = self._get_param("DatabaseName") tables = self.timestreamwrite_backend.list_tables(database_name) return json.dumps(dict(Tables=[t.description() for t in tables])) - def update_table(self): + def update_table(self) -> str: database_name = self._get_param("DatabaseName") table_name = self._get_param("TableName") retention_properties = self._get_param("RetentionProperties") @@ -95,7 +95,7 @@ class TimestreamWriteResponse(BaseResponse): ) return json.dumps(dict(Table=table.description())) - def write_records(self): + def write_records(self) -> str: database_name = self._get_param("DatabaseName") table_name = self._get_param("TableName") records = self._get_param("Records") @@ -109,22 +109,22 @@ class TimestreamWriteResponse(BaseResponse): } return json.dumps(resp) - def describe_endpoints(self): + def describe_endpoints(self) -> str: resp = self.timestreamwrite_backend.describe_endpoints() return json.dumps(resp) - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> str: resource_arn = self._get_param("ResourceARN") tags = self.timestreamwrite_backend.list_tags_for_resource(resource_arn) return json.dumps(tags) - def tag_resource(self): + def tag_resource(self) -> str: resource_arn = self._get_param("ResourceARN") tags = self._get_param("Tags") self.timestreamwrite_backend.tag_resource(resource_arn, tags) return "{}" - def untag_resource(self): + def untag_resource(self) -> str: resource_arn = self._get_param("ResourceARN") tag_keys = self._get_param("TagKeys") self.timestreamwrite_backend.untag_resource(resource_arn, tag_keys) diff --git a/moto/transcribe/exceptions.py b/moto/transcribe/exceptions.py index 948f5665b..d8381e8db 100644 --- a/moto/transcribe/exceptions.py +++ b/moto/transcribe/exceptions.py @@ -2,10 +2,10 @@ from moto.core.exceptions import JsonRESTError class ConflictException(JsonRESTError): - def __init__(self, message, **kwargs): - super().__init__("ConflictException", message, **kwargs) + def __init__(self, message: str): + super().__init__("ConflictException", message) class BadRequestException(JsonRESTError): - def __init__(self, message, **kwargs): - super().__init__("BadRequestException", message, **kwargs) + def __init__(self, message: str): + super().__init__("BadRequestException", message) diff --git a/moto/transcribe/models.py b/moto/transcribe/models.py index 498389eff..75c1bb955 100644 --- a/moto/transcribe/models.py +++ b/moto/transcribe/models.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api import state_manager from moto.moto_api._internal import mock_random @@ -7,14 +8,14 @@ from .exceptions import ConflictException, BadRequestException class BaseObject(BaseModel): - def camelCase(self, key): + def camelCase(self, key: str) -> str: words = [] for word in key.split("_"): words.append(word.title()) return "".join(words) - def gen_response_object(self): - response_object = dict() + def gen_response_object(self) -> Dict[str, Any]: + response_object: Dict[str, Any] = dict() for key, value in self.__dict__.items(): if "_" in key: response_object[self.camelCase(key)] = value @@ -23,30 +24,30 @@ class BaseObject(BaseModel): return response_object @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] return self.gen_response_object() class FakeTranscriptionJob(BaseObject, ManagedState): def __init__( self, - account_id, - region_name, - transcription_job_name, - language_code, - media_sample_rate_hertz, - media_format, - media, - output_bucket_name, - output_key, - output_encryption_kms_key_id, - settings, - model_settings, - job_execution_settings, - content_redaction, - identify_language, - identify_multiple_languages, - language_options, + account_id: str, + region_name: str, + transcription_job_name: str, + language_code: Optional[str], + media_sample_rate_hertz: Optional[int], + media_format: Optional[str], + media: Dict[str, str], + output_bucket_name: Optional[str], + output_key: Optional[str], + output_encryption_kms_key_id: Optional[str], + settings: Optional[Dict[str, Any]], + model_settings: Optional[Dict[str, Optional[str]]], + job_execution_settings: Optional[Dict[str, Any]], + content_redaction: Optional[Dict[str, Any]], + identify_language: Optional[bool], + identify_multiple_languages: Optional[bool], + language_options: Optional[List[str]], ): ManagedState.__init__( self, @@ -61,12 +62,13 @@ class FakeTranscriptionJob(BaseObject, ManagedState): self._region_name = region_name self.transcription_job_name = transcription_job_name self.language_code = language_code - self.language_codes = None + self.language_codes: Optional[List[Dict[str, Any]]] = None self.media_sample_rate_hertz = media_sample_rate_hertz self.media_format = media_format self.media = media - self.transcript = None - self.start_time = self.completion_time = None + self.transcript: Optional[Dict[str, str]] = None + self.start_time: Optional[str] = None + self.completion_time: Optional[str] = None self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.failure_reason = None self.settings = settings or { @@ -86,7 +88,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): self.identify_language = identify_language self.identify_multiple_languages = identify_multiple_languages self.language_options = language_options - self.identified_language_score = (None,) + self.identified_language_score: Optional[float] = None self._output_bucket_name = output_bucket_name self.output_key = output_key self._output_encryption_kms_key_id = output_encryption_kms_key_id @@ -94,7 +96,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): "CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET" ) - def response_object(self, response_type): + def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore response_field_dict = { "CREATE": [ "TranscriptionJobName", @@ -162,7 +164,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): if k in response_fields and v is not None and v != [None] } - def advance(self): + def advance(self) -> None: old_status = self.status super().advance() new_status = self.status @@ -191,20 +193,20 @@ class FakeTranscriptionJob(BaseObject, ManagedState): self.identified_language_score = 0.999645948 # Identify first two languages passed in language_options # If none is set, default to "en-US" - self.language_codes = [] + self.language_codes: List[Dict[str, Any]] = [] # type: ignore[no-redef] if self.language_options is None or len(self.language_options) == 0: - self.language_codes.append( + self.language_codes.append( # type: ignore {"LanguageCode": "en-US", "DurationInSeconds": 123.0} ) else: - self.language_codes.append( + self.language_codes.append( # type: ignore { "LanguageCode": self.language_options[0], "DurationInSeconds": 123.0, } ) if len(self.language_options) > 1: - self.language_codes.append( + self.language_codes.append( # type: ignore { "LanguageCode": self.language_options[1], "DurationInSeconds": 321.0, @@ -229,12 +231,12 @@ class FakeTranscriptionJob(BaseObject, ManagedState): class FakeVocabulary(BaseObject, ManagedState): def __init__( self, - account_id, - region_name, - vocabulary_name, - language_code, - phrases, - vocabulary_file_uri, + account_id: str, + region_name: str, + vocabulary_name: str, + language_code: str, + phrases: Optional[List[str]], + vocabulary_file_uri: Optional[str], ): # Configured ManagedState super().__init__( @@ -247,11 +249,11 @@ class FakeVocabulary(BaseObject, ManagedState): self.language_code = language_code self.phrases = phrases self.vocabulary_file_uri = vocabulary_file_uri - self.last_modified_time = None + self.last_modified_time: Optional[str] = None self.failure_reason = None self.download_uri = f"https://s3.{region_name}.amazonaws.com/aws-transcribe-dictionary-model-{region_name}-prod/{account_id}/{vocabulary_name}/{mock_random.uuid4()}/input.txt" - def response_object(self, response_type): + def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore response_field_dict = { "CREATE": [ "VocabularyName", @@ -284,7 +286,7 @@ class FakeVocabulary(BaseObject, ManagedState): if k in response_fields and v is not None and v != [None] } - def advance(self): + def advance(self) -> None: old_status = self.status super().advance() new_status = self.status @@ -296,17 +298,17 @@ class FakeVocabulary(BaseObject, ManagedState): class FakeMedicalTranscriptionJob(BaseObject, ManagedState): def __init__( self, - region_name, - medical_transcription_job_name, - language_code, - media_sample_rate_hertz, - media_format, - media, - output_bucket_name, - output_encryption_kms_key_id, - settings, - specialty, - job_type, + region_name: str, + medical_transcription_job_name: str, + language_code: str, + media_sample_rate_hertz: Optional[int], + media_format: Optional[str], + media: Dict[str, str], + output_bucket_name: str, + output_encryption_kms_key_id: Optional[str], + settings: Optional[Dict[str, Any]], + specialty: str, + job_type: str, ): ManagedState.__init__( self, @@ -323,8 +325,9 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState): self.media_sample_rate_hertz = media_sample_rate_hertz self.media_format = media_format self.media = media - self.transcript = None - self.start_time = self.completion_time = None + self.transcript: Optional[Dict[str, str]] = None + self.start_time: Optional[str] = None + self.completion_time: Optional[str] = None self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.failure_reason = None self.settings = settings or { @@ -337,7 +340,7 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState): self._output_encryption_kms_key_id = output_encryption_kms_key_id self.output_location_type = "CUSTOMER_BUCKET" - def response_object(self, response_type): + def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore response_field_dict = { "CREATE": [ "MedicalTranscriptionJobName", @@ -396,7 +399,7 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState): if k in response_fields and v is not None and v != [None] } - def advance(self): + def advance(self) -> None: old_status = self.status super().advance() new_status = self.status @@ -425,11 +428,11 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState): class FakeMedicalVocabulary(FakeVocabulary): def __init__( self, - account_id, - region_name, - vocabulary_name, - language_code, - vocabulary_file_uri, + account_id: str, + region_name: str, + vocabulary_name: str, + language_code: str, + vocabulary_file_uri: Optional[str], ): super().__init__( account_id, @@ -450,12 +453,12 @@ class FakeMedicalVocabulary(FakeVocabulary): class TranscribeBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.medical_transcriptions = {} - self.transcriptions = {} - self.medical_vocabularies = {} - self.vocabularies = {} + self.medical_transcriptions: Dict[str, FakeMedicalTranscriptionJob] = {} + self.transcriptions: Dict[str, FakeTranscriptionJob] = {} + self.medical_vocabularies: Dict[str, FakeMedicalVocabulary] = {} + self.vocabularies: Dict[str, FakeVocabulary] = {} state_manager.register_default_transition( "transcribe::vocabulary", transition={"progression": "manual", "times": 1} @@ -474,7 +477,9 @@ class TranscribeBackend(BaseBackend): ) @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint services.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "transcribe" @@ -482,15 +487,29 @@ class TranscribeBackend(BaseBackend): service_region, zones, "transcribestreaming" ) - def start_transcription_job(self, **kwargs): - - name = kwargs.get("transcription_job_name") - if name in self.transcriptions: + def start_transcription_job( + self, + transcription_job_name: str, + language_code: Optional[str], + media_sample_rate_hertz: Optional[int], + media_format: Optional[str], + media: Dict[str, str], + output_bucket_name: Optional[str], + output_key: Optional[str], + output_encryption_kms_key_id: Optional[str], + settings: Optional[Dict[str, Any]], + model_settings: Optional[Dict[str, Optional[str]]], + job_execution_settings: Optional[Dict[str, Any]], + content_redaction: Optional[Dict[str, Any]], + identify_language: Optional[bool], + identify_multiple_languages: Optional[bool], + language_options: Optional[List[str]], + ) -> Dict[str, Any]: + if transcription_job_name in self.transcriptions: raise ConflictException( message="The requested job name already exists. Use a different job name." ) - settings = kwargs.get("settings") vocabulary_name = settings.get("VocabularyName") if settings else None if vocabulary_name and vocabulary_name not in self.vocabularies: raise BadRequestException( @@ -501,36 +520,45 @@ class TranscribeBackend(BaseBackend): transcription_job_object = FakeTranscriptionJob( account_id=self.account_id, region_name=self.region_name, - transcription_job_name=name, - language_code=kwargs.get("language_code"), - media_sample_rate_hertz=kwargs.get("media_sample_rate_hertz"), - media_format=kwargs.get("media_format"), - media=kwargs.get("media"), - output_bucket_name=kwargs.get("output_bucket_name"), - output_key=kwargs.get("output_key"), - output_encryption_kms_key_id=kwargs.get("output_encryption_kms_key_id"), + transcription_job_name=transcription_job_name, + language_code=language_code, + media_sample_rate_hertz=media_sample_rate_hertz, + media_format=media_format, + media=media, + output_bucket_name=output_bucket_name, + output_key=output_key, + output_encryption_kms_key_id=output_encryption_kms_key_id, settings=settings, - model_settings=kwargs.get("model_settings"), - job_execution_settings=kwargs.get("job_execution_settings"), - content_redaction=kwargs.get("content_redaction"), - identify_language=kwargs.get("identify_language"), - identify_multiple_languages=kwargs.get("identify_multiple_languages"), - language_options=kwargs.get("language_options"), + model_settings=model_settings, + job_execution_settings=job_execution_settings, + content_redaction=content_redaction, + identify_language=identify_language, + identify_multiple_languages=identify_multiple_languages, + language_options=language_options, ) - self.transcriptions[name] = transcription_job_object + self.transcriptions[transcription_job_name] = transcription_job_object return transcription_job_object.response_object("CREATE") - def start_medical_transcription_job(self, **kwargs): + def start_medical_transcription_job( + self, + medical_transcription_job_name: str, + language_code: str, + media_sample_rate_hertz: Optional[int], + media_format: Optional[str], + media: Dict[str, str], + output_bucket_name: str, + output_encryption_kms_key_id: Optional[str], + settings: Optional[Dict[str, Any]], + specialty: str, + type_: str, + ) -> Dict[str, Any]: - name = kwargs.get("medical_transcription_job_name") - - if name in self.medical_transcriptions: + if medical_transcription_job_name in self.medical_transcriptions: raise ConflictException( message="The requested job name already exists. Use a different job name." ) - settings = kwargs.get("settings") vocabulary_name = settings.get("VocabularyName") if settings else None if vocabulary_name and vocabulary_name not in self.medical_vocabularies: raise BadRequestException( @@ -540,23 +568,25 @@ class TranscribeBackend(BaseBackend): transcription_job_object = FakeMedicalTranscriptionJob( region_name=self.region_name, - medical_transcription_job_name=name, - language_code=kwargs.get("language_code"), - media_sample_rate_hertz=kwargs.get("media_sample_rate_hertz"), - media_format=kwargs.get("media_format"), - media=kwargs.get("media"), - output_bucket_name=kwargs.get("output_bucket_name"), - output_encryption_kms_key_id=kwargs.get("output_encryption_kms_key_id"), + medical_transcription_job_name=medical_transcription_job_name, + language_code=language_code, + media_sample_rate_hertz=media_sample_rate_hertz, + media_format=media_format, + media=media, + output_bucket_name=output_bucket_name, + output_encryption_kms_key_id=output_encryption_kms_key_id, settings=settings, - specialty=kwargs.get("specialty"), - job_type=kwargs.get("type"), + specialty=specialty, + job_type=type_, ) - self.medical_transcriptions[name] = transcription_job_object + self.medical_transcriptions[ + medical_transcription_job_name + ] = transcription_job_object return transcription_job_object.response_object("CREATE") - def get_transcription_job(self, transcription_job_name): + def get_transcription_job(self, transcription_job_name: str) -> Dict[str, Any]: try: job = self.transcriptions[transcription_job_name] job.advance() # Fakes advancement through statuses. @@ -567,7 +597,9 @@ class TranscribeBackend(BaseBackend): "Check the job name and try your request again." ) - def get_medical_transcription_job(self, medical_transcription_job_name): + def get_medical_transcription_job( + self, medical_transcription_job_name: str + ) -> Dict[str, Any]: try: job = self.medical_transcriptions[medical_transcription_job_name] job.advance() # Fakes advancement through statuses. @@ -578,7 +610,7 @@ class TranscribeBackend(BaseBackend): "Check the job name and try your request again." ) - def delete_transcription_job(self, transcription_job_name): + def delete_transcription_job(self, transcription_job_name: str) -> None: try: del self.transcriptions[transcription_job_name] except KeyError: @@ -587,7 +619,9 @@ class TranscribeBackend(BaseBackend): "Check the job name and try your request again." ) - def delete_medical_transcription_job(self, medical_transcription_job_name): + def delete_medical_transcription_job( + self, medical_transcription_job_name: str + ) -> None: try: del self.medical_transcriptions[medical_transcription_job_name] except KeyError: @@ -597,8 +631,12 @@ class TranscribeBackend(BaseBackend): ) def list_transcription_jobs( - self, state_equals, job_name_contains, next_token, max_results - ): + self, + state_equals: str, + job_name_contains: str, + next_token: str, + max_results: int, + ) -> Dict[str, Any]: jobs = list(self.transcriptions.values()) if state_equals: @@ -615,7 +653,7 @@ class TranscribeBackend(BaseBackend): ) # Arbitrarily selected... jobs_paginated = jobs[start_offset:end_offset] - response = { + response: Dict[str, Any] = { "TranscriptionJobSummaries": [ job.response_object("LIST") for job in jobs_paginated ] @@ -627,8 +665,8 @@ class TranscribeBackend(BaseBackend): return response def list_medical_transcription_jobs( - self, status, job_name_contains, next_token, max_results - ): + self, status: str, job_name_contains: str, next_token: str, max_results: int + ) -> Dict[str, Any]: jobs = list(self.medical_transcriptions.values()) if status: @@ -647,7 +685,7 @@ class TranscribeBackend(BaseBackend): ) # Arbitrarily selected... jobs_paginated = jobs[start_offset:end_offset] - response = { + response: Dict[str, Any] = { "MedicalTranscriptionJobSummaries": [ job.response_object("LIST") for job in jobs_paginated ] @@ -658,12 +696,13 @@ class TranscribeBackend(BaseBackend): response["Status"] = status return response - def create_vocabulary(self, **kwargs): - - vocabulary_name = kwargs.get("vocabulary_name") - language_code = kwargs.get("language_code") - phrases = kwargs.get("phrases") - vocabulary_file_uri = kwargs.get("vocabulary_file_uri") + def create_vocabulary( + self, + vocabulary_name: str, + language_code: str, + phrases: Optional[List[str]], + vocabulary_file_uri: Optional[str], + ) -> Dict[str, Any]: if ( phrases is not None and vocabulary_file_uri is not None @@ -698,12 +737,12 @@ class TranscribeBackend(BaseBackend): return vocabulary_object.response_object("CREATE") - def create_medical_vocabulary(self, **kwargs): - - vocabulary_name = kwargs.get("vocabulary_name") - language_code = kwargs.get("language_code") - vocabulary_file_uri = kwargs.get("vocabulary_file_uri") - + def create_medical_vocabulary( + self, + vocabulary_name: str, + language_code: str, + vocabulary_file_uri: Optional[str], + ) -> Dict[str, Any]: if vocabulary_name in self.medical_vocabularies: raise ConflictException( message="The requested vocabulary name already exists. " @@ -722,7 +761,7 @@ class TranscribeBackend(BaseBackend): return medical_vocabulary_object.response_object("CREATE") - def get_vocabulary(self, vocabulary_name): + def get_vocabulary(self, vocabulary_name: str) -> Dict[str, Any]: try: job = self.vocabularies[vocabulary_name] job.advance() # Fakes advancement through statuses. @@ -733,7 +772,7 @@ class TranscribeBackend(BaseBackend): "Check the vocabulary name and try your request again." ) - def get_medical_vocabulary(self, vocabulary_name): + def get_medical_vocabulary(self, vocabulary_name: str) -> Dict[str, Any]: try: job = self.medical_vocabularies[vocabulary_name] job.advance() # Fakes advancement through statuses. @@ -744,7 +783,7 @@ class TranscribeBackend(BaseBackend): "Check the vocabulary name and try your request again." ) - def delete_vocabulary(self, vocabulary_name): + def delete_vocabulary(self, vocabulary_name: str) -> None: try: del self.vocabularies[vocabulary_name] except KeyError: @@ -752,7 +791,7 @@ class TranscribeBackend(BaseBackend): message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again." ) - def delete_medical_vocabulary(self, vocabulary_name): + def delete_medical_vocabulary(self, vocabulary_name: str) -> None: try: del self.medical_vocabularies[vocabulary_name] except KeyError: @@ -760,7 +799,9 @@ class TranscribeBackend(BaseBackend): message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again." ) - def list_vocabularies(self, state_equals, name_contains, next_token, max_results): + def list_vocabularies( + self, state_equals: str, name_contains: str, next_token: str, max_results: int + ) -> Dict[str, Any]: vocabularies = list(self.vocabularies.values()) if state_equals: @@ -783,7 +824,7 @@ class TranscribeBackend(BaseBackend): ) # Arbitrarily selected... vocabularies_paginated = vocabularies[start_offset:end_offset] - response = { + response: Dict[str, Any] = { "Vocabularies": [ vocabulary.response_object("LIST") for vocabulary in vocabularies_paginated @@ -796,8 +837,8 @@ class TranscribeBackend(BaseBackend): return response def list_medical_vocabularies( - self, state_equals, name_contains, next_token, max_results - ): + self, state_equals: str, name_contains: str, next_token: str, max_results: int + ) -> Dict[str, Any]: vocabularies = list(self.medical_vocabularies.values()) if state_equals: @@ -820,7 +861,7 @@ class TranscribeBackend(BaseBackend): ) # Arbitrarily selected... vocabularies_paginated = vocabularies[start_offset:end_offset] - response = { + response: Dict[str, Any] = { "Vocabularies": [ vocabulary.response_object("LIST") for vocabulary in vocabularies_paginated diff --git a/moto/transcribe/responses.py b/moto/transcribe/responses.py index 6fb54c397..a74771c3a 100644 --- a/moto/transcribe/responses.py +++ b/moto/transcribe/responses.py @@ -2,26 +2,19 @@ import json from moto.core.responses import BaseResponse from moto.utilities.aws_headers import amzn_request_id -from .models import transcribe_backends +from .models import transcribe_backends, TranscribeBackend class TranscribeResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="transcribe") @property - def transcribe_backend(self): + def transcribe_backend(self) -> TranscribeBackend: return transcribe_backends[self.current_account][self.region] - @property - def request_params(self): - try: - return json.loads(self.body) - except ValueError: - return {} - @amzn_request_id - def start_transcription_job(self): + def start_transcription_job(self) -> str: name = self._get_param("TranscriptionJobName") response = self.transcribe_backend.start_transcription_job( transcription_job_name=name, @@ -43,7 +36,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def start_medical_transcription_job(self): + def start_medical_transcription_job(self) -> str: name = self._get_param("MedicalTranscriptionJobName") response = self.transcribe_backend.start_medical_transcription_job( medical_transcription_job_name=name, @@ -55,12 +48,12 @@ class TranscribeResponse(BaseResponse): output_encryption_kms_key_id=self._get_param("OutputEncryptionKMSKeyId"), settings=self._get_param("Settings"), specialty=self._get_param("Specialty"), - type=self._get_param("Type"), + type_=self._get_param("Type"), ) return json.dumps(response) @amzn_request_id - def list_transcription_jobs(self): + def list_transcription_jobs(self) -> str: state_equals = self._get_param("Status") job_name_contains = self._get_param("JobNameContains") next_token = self._get_param("NextToken") @@ -75,7 +68,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def list_medical_transcription_jobs(self): + def list_medical_transcription_jobs(self) -> str: status = self._get_param("Status") job_name_contains = self._get_param("JobNameContains") next_token = self._get_param("NextToken") @@ -90,7 +83,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def get_transcription_job(self): + def get_transcription_job(self) -> str: transcription_job_name = self._get_param("TranscriptionJobName") response = self.transcribe_backend.get_transcription_job( transcription_job_name=transcription_job_name @@ -98,7 +91,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def get_medical_transcription_job(self): + def get_medical_transcription_job(self) -> str: medical_transcription_job_name = self._get_param("MedicalTranscriptionJobName") response = self.transcribe_backend.get_medical_transcription_job( medical_transcription_job_name=medical_transcription_job_name @@ -106,23 +99,23 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def delete_transcription_job(self): + def delete_transcription_job(self) -> str: transcription_job_name = self._get_param("TranscriptionJobName") - response = self.transcribe_backend.delete_transcription_job( + self.transcribe_backend.delete_transcription_job( transcription_job_name=transcription_job_name ) - return json.dumps(response) + return "{}" @amzn_request_id - def delete_medical_transcription_job(self): + def delete_medical_transcription_job(self) -> str: medical_transcription_job_name = self._get_param("MedicalTranscriptionJobName") - response = self.transcribe_backend.delete_medical_transcription_job( + self.transcribe_backend.delete_medical_transcription_job( medical_transcription_job_name=medical_transcription_job_name ) - return json.dumps(response) + return "{}" @amzn_request_id - def create_vocabulary(self): + def create_vocabulary(self) -> str: vocabulary_name = self._get_param("VocabularyName") language_code = self._get_param("LanguageCode") phrases = self._get_param("Phrases") @@ -136,7 +129,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def create_medical_vocabulary(self): + def create_medical_vocabulary(self) -> str: vocabulary_name = self._get_param("VocabularyName") language_code = self._get_param("LanguageCode") vocabulary_file_uri = self._get_param("VocabularyFileUri") @@ -148,7 +141,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def get_vocabulary(self): + def get_vocabulary(self) -> str: vocabulary_name = self._get_param("VocabularyName") response = self.transcribe_backend.get_vocabulary( vocabulary_name=vocabulary_name @@ -156,7 +149,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def get_medical_vocabulary(self): + def get_medical_vocabulary(self) -> str: vocabulary_name = self._get_param("VocabularyName") response = self.transcribe_backend.get_medical_vocabulary( vocabulary_name=vocabulary_name @@ -164,7 +157,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def list_vocabularies(self): + def list_vocabularies(self) -> str: state_equals = self._get_param("StateEquals") name_contains = self._get_param("NameContains") next_token = self._get_param("NextToken") @@ -179,7 +172,7 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def list_medical_vocabularies(self): + def list_medical_vocabularies(self) -> str: state_equals = self._get_param("StateEquals") name_contains = self._get_param("NameContains") next_token = self._get_param("NextToken") @@ -194,17 +187,15 @@ class TranscribeResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def delete_vocabulary(self): + def delete_vocabulary(self) -> str: vocabulary_name = self._get_param("VocabularyName") - response = self.transcribe_backend.delete_vocabulary( - vocabulary_name=vocabulary_name - ) - return json.dumps(response) + self.transcribe_backend.delete_vocabulary(vocabulary_name=vocabulary_name) + return "{}" @amzn_request_id - def delete_medical_vocabulary(self): + def delete_medical_vocabulary(self) -> str: vocabulary_name = self._get_param("VocabularyName") - response = self.transcribe_backend.delete_medical_vocabulary( + self.transcribe_backend.delete_medical_vocabulary( vocabulary_name=vocabulary_name ) - return json.dumps(response) + return "{}" diff --git a/setup.cfg b/setup.cfg index 8fe537bee..6d37533a8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -239,7 +239,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u* +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u*,moto/t* show_column_numbers=True show_error_codes = True disable_error_code=abstract