Techdebt: MyPy T (#6270)

This commit is contained in:
Bert Blommers 2023-04-29 22:21:00 +00:00 committed by GitHub
parent e5e1521523
commit 7d6afe4b67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 315 additions and 266 deletions

View File

@ -5,16 +5,16 @@ from moto.core.exceptions import JsonRESTError
class InvalidJobIdException(JsonRESTError): class InvalidJobIdException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__(__class__.__name__, "An invalid job identifier was passed.") super().__init__(__class__.__name__, "An invalid job identifier was passed.") # type: ignore
class InvalidS3ObjectException(JsonRESTError): class InvalidS3ObjectException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
__class__.__name__, __class__.__name__, # type: ignore
"Amazon Textract is unable to access the S3 object that's specified in the request.", "Amazon Textract is unable to access the S3 object that's specified in the request.",
) )
@ -22,8 +22,8 @@ class InvalidS3ObjectException(JsonRESTError):
class InvalidParameterException(JsonRESTError): class InvalidParameterException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
__class__.__name__, __class__.__name__, # type: ignore
"An input parameter violated a constraint. For example, in synchronous operations, an InvalidParameterException exception occurs when neither of the S3Object or Bytes values are supplied in the Document request parameter. Validate your parameter before calling the API operation again.", "An input parameter violated a constraint. For example, in synchronous operations, an InvalidParameterException exception occurs when neither of the S3Object or Bytes values are supplied in the Document request parameter. Validate your parameter before calling the API operation again.",
) )

View File

@ -1,6 +1,5 @@
"""TextractBackend class with methods for supported APIs."""
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
@ -16,10 +15,10 @@ class TextractJobStatus:
class TextractJob(BaseModel): class TextractJob(BaseModel):
def __init__(self, job): def __init__(self, job: Dict[str, Any]):
self.job = job self.job = job
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return self.job return self.job
@ -28,13 +27,13 @@ class TextractBackend(BaseBackend):
JOB_STATUS = TextractJobStatus.succeeded JOB_STATUS = TextractJobStatus.succeeded
PAGES = {"Pages": mock_random.randint(5, 500)} PAGES = {"Pages": mock_random.randint(5, 500)}
BLOCKS = [] BLOCKS: List[Dict[str, Any]] = []
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.async_text_detection_jobs = defaultdict() self.async_text_detection_jobs: Dict[str, TextractJob] = defaultdict()
def get_document_text_detection(self, job_id): def get_document_text_detection(self, job_id: str) -> TextractJob:
""" """
Pagination has not yet been implemented Pagination has not yet been implemented
""" """
@ -43,7 +42,7 @@ class TextractBackend(BaseBackend):
raise InvalidJobIdException() raise InvalidJobIdException()
return job return job
def start_document_text_detection(self, document_location): def start_document_text_detection(self, document_location: str) -> str:
""" """
The following parameters have not yet been implemented: ClientRequestToken, JobTag, NotificationChannel, OutputConfig, KmsKeyID The following parameters have not yet been implemented: ClientRequestToken, JobTag, NotificationChannel, OutputConfig, KmsKeyID
""" """

View File

@ -2,27 +2,27 @@
import json import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import textract_backends from .models import textract_backends, TextractBackend
class TextractResponse(BaseResponse): class TextractResponse(BaseResponse):
"""Handler for Textract requests and responses.""" """Handler for Textract requests and responses."""
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="textract") super().__init__(service_name="textract")
@property @property
def textract_backend(self): def textract_backend(self) -> TextractBackend:
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return textract_backends[self.current_account][self.region] return textract_backends[self.current_account][self.region]
def get_document_text_detection(self): def get_document_text_detection(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
job_id = params.get("JobId") job_id = params.get("JobId")
job = self.textract_backend.get_document_text_detection(job_id=job_id).to_dict() job = self.textract_backend.get_document_text_detection(job_id=job_id).to_dict()
return json.dumps(job) return json.dumps(job)
def start_document_text_detection(self): def start_document_text_detection(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
document_location = params.get("DocumentLocation") document_location = params.get("DocumentLocation")
job_id = self.textract_backend.start_document_text_detection( job_id = self.textract_backend.start_document_text_detection(

View File

@ -5,5 +5,5 @@ from moto.core.exceptions import JsonRESTError
class ResourceNotFound(JsonRESTError): class ResourceNotFound(JsonRESTError):
error_type = "com.amazonaws.timestream.v20181101#ResourceNotFoundException" error_type = "com.amazonaws.timestream.v20181101#ResourceNotFoundException"
def __init__(self, msg): def __init__(self, msg: str):
super().__init__(ResourceNotFound.error_type, msg) super().__init__(ResourceNotFound.error_type, msg)

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Iterable
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
from .exceptions import ResourceNotFound from .exceptions import ResourceNotFound
@ -6,12 +7,12 @@ from .exceptions import ResourceNotFound
class TimestreamTable(BaseModel): class TimestreamTable(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region_name, region_name: str,
table_name, table_name: str,
db_name, db_name: str,
retention_properties, retention_properties: Dict[str, int],
magnetic_store_write_properties, magnetic_store_write_properties: Dict[str, Any],
): ):
self.region_name = region_name self.region_name = region_name
self.name = table_name self.name = table_name
@ -21,18 +22,22 @@ class TimestreamTable(BaseModel):
"MagneticStoreRetentionPeriodInDays": 123, "MagneticStoreRetentionPeriodInDays": 123,
} }
self.magnetic_store_write_properties = magnetic_store_write_properties or {} self.magnetic_store_write_properties = magnetic_store_write_properties or {}
self.records = [] self.records: List[Dict[str, Any]] = []
self.arn = f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.db_name}/table/{self.name}" self.arn = f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.db_name}/table/{self.name}"
def update(self, retention_properties, magnetic_store_write_properties): def update(
self,
retention_properties: Dict[str, int],
magnetic_store_write_properties: Dict[str, Any],
) -> None:
self.retention_properties = retention_properties self.retention_properties = retention_properties
if magnetic_store_write_properties is not None: if magnetic_store_write_properties is not None:
self.magnetic_store_write_properties = magnetic_store_write_properties self.magnetic_store_write_properties = magnetic_store_write_properties
def write_records(self, records): def write_records(self, records: List[Dict[str, Any]]) -> None:
self.records.extend(records) self.records.extend(records)
def description(self): def description(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"TableName": self.name, "TableName": self.name,
@ -44,7 +49,9 @@ class TimestreamTable(BaseModel):
class TimestreamDatabase(BaseModel): class TimestreamDatabase(BaseModel):
def __init__(self, account_id, region_name, database_name, kms_key_id): def __init__(
self, account_id: str, region_name: str, database_name: str, kms_key_id: str
):
self.account_id = account_id self.account_id = account_id
self.region_name = region_name self.region_name = region_name
self.name = database_name self.name = database_name
@ -54,14 +61,17 @@ class TimestreamDatabase(BaseModel):
self.arn = ( self.arn = (
f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.name}" f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.name}"
) )
self.tables = dict() self.tables: Dict[str, TimestreamTable] = dict()
def update(self, kms_key_id): def update(self, kms_key_id: str) -> None:
self.kms_key_id = kms_key_id self.kms_key_id = kms_key_id
def create_table( def create_table(
self, table_name, retention_properties, magnetic_store_write_properties self,
): table_name: str,
retention_properties: Dict[str, int],
magnetic_store_write_properties: Dict[str, Any],
) -> TimestreamTable:
table = TimestreamTable( table = TimestreamTable(
account_id=self.account_id, account_id=self.account_id,
region_name=self.region_name, region_name=self.region_name,
@ -74,8 +84,11 @@ class TimestreamDatabase(BaseModel):
return table return table
def update_table( def update_table(
self, table_name, retention_properties, magnetic_store_write_properties self,
): table_name: str,
retention_properties: Dict[str, int],
magnetic_store_write_properties: Dict[str, Any],
) -> TimestreamTable:
table = self.tables[table_name] table = self.tables[table_name]
table.update( table.update(
retention_properties=retention_properties, retention_properties=retention_properties,
@ -83,18 +96,18 @@ class TimestreamDatabase(BaseModel):
) )
return table return table
def delete_table(self, table_name): def delete_table(self, table_name: str) -> None:
self.tables.pop(table_name, None) self.tables.pop(table_name, None)
def describe_table(self, table_name): def describe_table(self, table_name: str) -> TimestreamTable:
if table_name not in self.tables: if table_name not in self.tables:
raise ResourceNotFound(f"The table {table_name} does not exist.") raise ResourceNotFound(f"The table {table_name} does not exist.")
return self.tables[table_name] return self.tables[table_name]
def list_tables(self): def list_tables(self) -> Iterable[TimestreamTable]:
return self.tables.values() return self.tables.values()
def description(self): def description(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"DatabaseName": self.name, "DatabaseName": self.name,
@ -117,12 +130,14 @@ class TimestreamWriteBackend(BaseBackend):
""" """
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.databases = dict() self.databases: Dict[str, TimestreamDatabase] = dict()
self.tagging_service = TaggingService() self.tagging_service = TaggingService()
def create_database(self, database_name, kms_key_id, tags): def create_database(
self, database_name: str, kms_key_id: str, tags: List[Dict[str, str]]
) -> TimestreamDatabase:
database = TimestreamDatabase( database = TimestreamDatabase(
self.account_id, self.region_name, database_name, kms_key_id self.account_id, self.region_name, database_name, kms_key_id
) )
@ -130,30 +145,32 @@ class TimestreamWriteBackend(BaseBackend):
self.tagging_service.tag_resource(database.arn, tags) self.tagging_service.tag_resource(database.arn, tags)
return database return database
def delete_database(self, database_name): def delete_database(self, database_name: str) -> None:
del self.databases[database_name] del self.databases[database_name]
def describe_database(self, database_name): def describe_database(self, database_name: str) -> TimestreamDatabase:
if database_name not in self.databases: if database_name not in self.databases:
raise ResourceNotFound(f"The database {database_name} does not exist.") raise ResourceNotFound(f"The database {database_name} does not exist.")
return self.databases[database_name] return self.databases[database_name]
def list_databases(self): def list_databases(self) -> Iterable[TimestreamDatabase]:
return self.databases.values() return self.databases.values()
def update_database(self, database_name, kms_key_id): def update_database(
self, database_name: str, kms_key_id: str
) -> TimestreamDatabase:
database = self.databases[database_name] database = self.databases[database_name]
database.update(kms_key_id=kms_key_id) database.update(kms_key_id=kms_key_id)
return database return database
def create_table( def create_table(
self, self,
database_name, database_name: str,
table_name, table_name: str,
retention_properties, retention_properties: Dict[str, int],
tags, tags: List[Dict[str, str]],
magnetic_store_write_properties, magnetic_store_write_properties: Dict[str, Any],
): ) -> TimestreamTable:
database = self.describe_database(database_name) database = self.describe_database(database_name)
table = database.create_table( table = database.create_table(
table_name, retention_properties, magnetic_store_write_properties table_name, retention_properties, magnetic_store_write_properties
@ -161,39 +178,38 @@ class TimestreamWriteBackend(BaseBackend):
self.tagging_service.tag_resource(table.arn, tags) self.tagging_service.tag_resource(table.arn, tags)
return table return table
def delete_table(self, database_name, table_name): def delete_table(self, database_name: str, table_name: str) -> None:
database = self.describe_database(database_name) database = self.describe_database(database_name)
database.delete_table(table_name) database.delete_table(table_name)
def describe_table(self, database_name, table_name): def describe_table(self, database_name: str, table_name: str) -> TimestreamTable:
database = self.describe_database(database_name) database = self.describe_database(database_name)
table = database.describe_table(table_name) return database.describe_table(table_name)
return table
def list_tables(self, database_name): def list_tables(self, database_name: str) -> Iterable[TimestreamTable]:
database = self.describe_database(database_name) database = self.describe_database(database_name)
tables = database.list_tables() return database.list_tables()
return tables
def update_table( def update_table(
self, self,
database_name, database_name: str,
table_name, table_name: str,
retention_properties, retention_properties: Dict[str, int],
magnetic_store_write_properties, magnetic_store_write_properties: Dict[str, Any],
): ) -> TimestreamTable:
database = self.describe_database(database_name) database = self.describe_database(database_name)
table = database.update_table( return database.update_table(
table_name, retention_properties, magnetic_store_write_properties table_name, retention_properties, magnetic_store_write_properties
) )
return table
def write_records(self, database_name, table_name, records): def write_records(
self, database_name: str, table_name: str, records: List[Dict[str, Any]]
) -> None:
database = self.describe_database(database_name) database = self.describe_database(database_name)
table = database.describe_table(table_name) table = database.describe_table(table_name)
table.write_records(records) table.write_records(records)
def describe_endpoints(self): def describe_endpoints(self) -> Dict[str, List[Dict[str, Any]]]:
# https://docs.aws.amazon.com/timestream/latest/developerguide/Using-API.endpoint-discovery.how-it-works.html # https://docs.aws.amazon.com/timestream/latest/developerguide/Using-API.endpoint-discovery.how-it-works.html
# Usually, the address look like this: # Usually, the address look like this:
# ingest-cell1.timestream.us-east-1.amazonaws.com # ingest-cell1.timestream.us-east-1.amazonaws.com
@ -208,13 +224,15 @@ class TimestreamWriteBackend(BaseBackend):
] ]
} }
def list_tags_for_resource(self, resource_arn): def list_tags_for_resource(
self, resource_arn: str
) -> Dict[str, List[Dict[str, str]]]:
return self.tagging_service.list_tags_for_resource(resource_arn) return self.tagging_service.list_tags_for_resource(resource_arn)
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
self.tagging_service.tag_resource(resource_arn, tags) self.tagging_service.tag_resource(resource_arn, tags)
def untag_resource(self, resource_arn, tag_keys): def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagging_service.untag_resource_using_names(resource_arn, tag_keys) self.tagging_service.untag_resource_using_names(resource_arn, tag_keys)

View File

@ -1,19 +1,19 @@
import json import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import timestreamwrite_backends from .models import timestreamwrite_backends, TimestreamWriteBackend
class TimestreamWriteResponse(BaseResponse): class TimestreamWriteResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="timestream-write") super().__init__(service_name="timestream-write")
@property @property
def timestreamwrite_backend(self): def timestreamwrite_backend(self) -> TimestreamWriteBackend:
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return timestreamwrite_backends[self.current_account][self.region] return timestreamwrite_backends[self.current_account][self.region]
def create_database(self): def create_database(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
kms_key_id = self._get_param("KmsKeyId") kms_key_id = self._get_param("KmsKeyId")
tags = self._get_param("Tags") tags = self._get_param("Tags")
@ -22,19 +22,19 @@ class TimestreamWriteResponse(BaseResponse):
) )
return json.dumps(dict(Database=database.description())) return json.dumps(dict(Database=database.description()))
def delete_database(self): def delete_database(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
self.timestreamwrite_backend.delete_database(database_name=database_name) self.timestreamwrite_backend.delete_database(database_name=database_name)
return "{}" return "{}"
def describe_database(self): def describe_database(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
database = self.timestreamwrite_backend.describe_database( database = self.timestreamwrite_backend.describe_database(
database_name=database_name database_name=database_name
) )
return json.dumps(dict(Database=database.description())) return json.dumps(dict(Database=database.description()))
def update_database(self): def update_database(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
kms_key_id = self._get_param("KmsKeyId") kms_key_id = self._get_param("KmsKeyId")
database = self.timestreamwrite_backend.update_database( database = self.timestreamwrite_backend.update_database(
@ -42,11 +42,11 @@ class TimestreamWriteResponse(BaseResponse):
) )
return json.dumps(dict(Database=database.description())) return json.dumps(dict(Database=database.description()))
def list_databases(self): def list_databases(self) -> str:
all_dbs = self.timestreamwrite_backend.list_databases() all_dbs = self.timestreamwrite_backend.list_databases()
return json.dumps(dict(Databases=[db.description() for db in all_dbs])) return json.dumps(dict(Databases=[db.description() for db in all_dbs]))
def create_table(self): def create_table(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName") table_name = self._get_param("TableName")
retention_properties = self._get_param("RetentionProperties") retention_properties = self._get_param("RetentionProperties")
@ -63,24 +63,24 @@ class TimestreamWriteResponse(BaseResponse):
) )
return json.dumps(dict(Table=table.description())) return json.dumps(dict(Table=table.description()))
def delete_table(self): def delete_table(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName") table_name = self._get_param("TableName")
self.timestreamwrite_backend.delete_table(database_name, table_name) self.timestreamwrite_backend.delete_table(database_name, table_name)
return "{}" return "{}"
def describe_table(self): def describe_table(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName") table_name = self._get_param("TableName")
table = self.timestreamwrite_backend.describe_table(database_name, table_name) table = self.timestreamwrite_backend.describe_table(database_name, table_name)
return json.dumps(dict(Table=table.description())) return json.dumps(dict(Table=table.description()))
def list_tables(self): def list_tables(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
tables = self.timestreamwrite_backend.list_tables(database_name) tables = self.timestreamwrite_backend.list_tables(database_name)
return json.dumps(dict(Tables=[t.description() for t in tables])) return json.dumps(dict(Tables=[t.description() for t in tables]))
def update_table(self): def update_table(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName") table_name = self._get_param("TableName")
retention_properties = self._get_param("RetentionProperties") retention_properties = self._get_param("RetentionProperties")
@ -95,7 +95,7 @@ class TimestreamWriteResponse(BaseResponse):
) )
return json.dumps(dict(Table=table.description())) return json.dumps(dict(Table=table.description()))
def write_records(self): def write_records(self) -> str:
database_name = self._get_param("DatabaseName") database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName") table_name = self._get_param("TableName")
records = self._get_param("Records") records = self._get_param("Records")
@ -109,22 +109,22 @@ class TimestreamWriteResponse(BaseResponse):
} }
return json.dumps(resp) return json.dumps(resp)
def describe_endpoints(self): def describe_endpoints(self) -> str:
resp = self.timestreamwrite_backend.describe_endpoints() resp = self.timestreamwrite_backend.describe_endpoints()
return json.dumps(resp) return json.dumps(resp)
def list_tags_for_resource(self): def list_tags_for_resource(self) -> str:
resource_arn = self._get_param("ResourceARN") resource_arn = self._get_param("ResourceARN")
tags = self.timestreamwrite_backend.list_tags_for_resource(resource_arn) tags = self.timestreamwrite_backend.list_tags_for_resource(resource_arn)
return json.dumps(tags) return json.dumps(tags)
def tag_resource(self): def tag_resource(self) -> str:
resource_arn = self._get_param("ResourceARN") resource_arn = self._get_param("ResourceARN")
tags = self._get_param("Tags") tags = self._get_param("Tags")
self.timestreamwrite_backend.tag_resource(resource_arn, tags) self.timestreamwrite_backend.tag_resource(resource_arn, tags)
return "{}" return "{}"
def untag_resource(self): def untag_resource(self) -> str:
resource_arn = self._get_param("ResourceARN") resource_arn = self._get_param("ResourceARN")
tag_keys = self._get_param("TagKeys") tag_keys = self._get_param("TagKeys")
self.timestreamwrite_backend.untag_resource(resource_arn, tag_keys) self.timestreamwrite_backend.untag_resource(resource_arn, tag_keys)

View File

@ -2,10 +2,10 @@ from moto.core.exceptions import JsonRESTError
class ConflictException(JsonRESTError): class ConflictException(JsonRESTError):
def __init__(self, message, **kwargs): def __init__(self, message: str):
super().__init__("ConflictException", message, **kwargs) super().__init__("ConflictException", message)
class BadRequestException(JsonRESTError): class BadRequestException(JsonRESTError):
def __init__(self, message, **kwargs): def __init__(self, message: str):
super().__init__("BadRequestException", message, **kwargs) super().__init__("BadRequestException", message)

View File

@ -1,4 +1,5 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api import state_manager from moto.moto_api import state_manager
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
@ -7,14 +8,14 @@ from .exceptions import ConflictException, BadRequestException
class BaseObject(BaseModel): class BaseObject(BaseModel):
def camelCase(self, key): def camelCase(self, key: str) -> str:
words = [] words = []
for word in key.split("_"): for word in key.split("_"):
words.append(word.title()) words.append(word.title())
return "".join(words) return "".join(words)
def gen_response_object(self): def gen_response_object(self) -> Dict[str, Any]:
response_object = dict() response_object: Dict[str, Any] = dict()
for key, value in self.__dict__.items(): for key, value in self.__dict__.items():
if "_" in key: if "_" in key:
response_object[self.camelCase(key)] = value response_object[self.camelCase(key)] = value
@ -23,30 +24,30 @@ class BaseObject(BaseModel):
return response_object return response_object
@property @property
def response_object(self): def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
return self.gen_response_object() return self.gen_response_object()
class FakeTranscriptionJob(BaseObject, ManagedState): class FakeTranscriptionJob(BaseObject, ManagedState):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region_name, region_name: str,
transcription_job_name, transcription_job_name: str,
language_code, language_code: Optional[str],
media_sample_rate_hertz, media_sample_rate_hertz: Optional[int],
media_format, media_format: Optional[str],
media, media: Dict[str, str],
output_bucket_name, output_bucket_name: Optional[str],
output_key, output_key: Optional[str],
output_encryption_kms_key_id, output_encryption_kms_key_id: Optional[str],
settings, settings: Optional[Dict[str, Any]],
model_settings, model_settings: Optional[Dict[str, Optional[str]]],
job_execution_settings, job_execution_settings: Optional[Dict[str, Any]],
content_redaction, content_redaction: Optional[Dict[str, Any]],
identify_language, identify_language: Optional[bool],
identify_multiple_languages, identify_multiple_languages: Optional[bool],
language_options, language_options: Optional[List[str]],
): ):
ManagedState.__init__( ManagedState.__init__(
self, self,
@ -61,12 +62,13 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self._region_name = region_name self._region_name = region_name
self.transcription_job_name = transcription_job_name self.transcription_job_name = transcription_job_name
self.language_code = language_code self.language_code = language_code
self.language_codes = None self.language_codes: Optional[List[Dict[str, Any]]] = None
self.media_sample_rate_hertz = media_sample_rate_hertz self.media_sample_rate_hertz = media_sample_rate_hertz
self.media_format = media_format self.media_format = media_format
self.media = media self.media = media
self.transcript = None self.transcript: Optional[Dict[str, str]] = None
self.start_time = self.completion_time = None self.start_time: Optional[str] = None
self.completion_time: Optional[str] = None
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.failure_reason = None self.failure_reason = None
self.settings = settings or { self.settings = settings or {
@ -86,7 +88,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self.identify_language = identify_language self.identify_language = identify_language
self.identify_multiple_languages = identify_multiple_languages self.identify_multiple_languages = identify_multiple_languages
self.language_options = language_options self.language_options = language_options
self.identified_language_score = (None,) self.identified_language_score: Optional[float] = None
self._output_bucket_name = output_bucket_name self._output_bucket_name = output_bucket_name
self.output_key = output_key self.output_key = output_key
self._output_encryption_kms_key_id = output_encryption_kms_key_id self._output_encryption_kms_key_id = output_encryption_kms_key_id
@ -94,7 +96,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET" "CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
) )
def response_object(self, response_type): def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
response_field_dict = { response_field_dict = {
"CREATE": [ "CREATE": [
"TranscriptionJobName", "TranscriptionJobName",
@ -162,7 +164,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
if k in response_fields and v is not None and v != [None] if k in response_fields and v is not None and v != [None]
} }
def advance(self): def advance(self) -> None:
old_status = self.status old_status = self.status
super().advance() super().advance()
new_status = self.status new_status = self.status
@ -191,20 +193,20 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self.identified_language_score = 0.999645948 self.identified_language_score = 0.999645948
# Identify first two languages passed in language_options # Identify first two languages passed in language_options
# If none is set, default to "en-US" # If none is set, default to "en-US"
self.language_codes = [] self.language_codes: List[Dict[str, Any]] = [] # type: ignore[no-redef]
if self.language_options is None or len(self.language_options) == 0: if self.language_options is None or len(self.language_options) == 0:
self.language_codes.append( self.language_codes.append( # type: ignore
{"LanguageCode": "en-US", "DurationInSeconds": 123.0} {"LanguageCode": "en-US", "DurationInSeconds": 123.0}
) )
else: else:
self.language_codes.append( self.language_codes.append( # type: ignore
{ {
"LanguageCode": self.language_options[0], "LanguageCode": self.language_options[0],
"DurationInSeconds": 123.0, "DurationInSeconds": 123.0,
} }
) )
if len(self.language_options) > 1: if len(self.language_options) > 1:
self.language_codes.append( self.language_codes.append( # type: ignore
{ {
"LanguageCode": self.language_options[1], "LanguageCode": self.language_options[1],
"DurationInSeconds": 321.0, "DurationInSeconds": 321.0,
@ -229,12 +231,12 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
class FakeVocabulary(BaseObject, ManagedState): class FakeVocabulary(BaseObject, ManagedState):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region_name, region_name: str,
vocabulary_name, vocabulary_name: str,
language_code, language_code: str,
phrases, phrases: Optional[List[str]],
vocabulary_file_uri, vocabulary_file_uri: Optional[str],
): ):
# Configured ManagedState # Configured ManagedState
super().__init__( super().__init__(
@ -247,11 +249,11 @@ class FakeVocabulary(BaseObject, ManagedState):
self.language_code = language_code self.language_code = language_code
self.phrases = phrases self.phrases = phrases
self.vocabulary_file_uri = vocabulary_file_uri self.vocabulary_file_uri = vocabulary_file_uri
self.last_modified_time = None self.last_modified_time: Optional[str] = None
self.failure_reason = None self.failure_reason = None
self.download_uri = f"https://s3.{region_name}.amazonaws.com/aws-transcribe-dictionary-model-{region_name}-prod/{account_id}/{vocabulary_name}/{mock_random.uuid4()}/input.txt" self.download_uri = f"https://s3.{region_name}.amazonaws.com/aws-transcribe-dictionary-model-{region_name}-prod/{account_id}/{vocabulary_name}/{mock_random.uuid4()}/input.txt"
def response_object(self, response_type): def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
response_field_dict = { response_field_dict = {
"CREATE": [ "CREATE": [
"VocabularyName", "VocabularyName",
@ -284,7 +286,7 @@ class FakeVocabulary(BaseObject, ManagedState):
if k in response_fields and v is not None and v != [None] if k in response_fields and v is not None and v != [None]
} }
def advance(self): def advance(self) -> None:
old_status = self.status old_status = self.status
super().advance() super().advance()
new_status = self.status new_status = self.status
@ -296,17 +298,17 @@ class FakeVocabulary(BaseObject, ManagedState):
class FakeMedicalTranscriptionJob(BaseObject, ManagedState): class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
def __init__( def __init__(
self, self,
region_name, region_name: str,
medical_transcription_job_name, medical_transcription_job_name: str,
language_code, language_code: str,
media_sample_rate_hertz, media_sample_rate_hertz: Optional[int],
media_format, media_format: Optional[str],
media, media: Dict[str, str],
output_bucket_name, output_bucket_name: str,
output_encryption_kms_key_id, output_encryption_kms_key_id: Optional[str],
settings, settings: Optional[Dict[str, Any]],
specialty, specialty: str,
job_type, job_type: str,
): ):
ManagedState.__init__( ManagedState.__init__(
self, self,
@ -323,8 +325,9 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
self.media_sample_rate_hertz = media_sample_rate_hertz self.media_sample_rate_hertz = media_sample_rate_hertz
self.media_format = media_format self.media_format = media_format
self.media = media self.media = media
self.transcript = None self.transcript: Optional[Dict[str, str]] = None
self.start_time = self.completion_time = None self.start_time: Optional[str] = None
self.completion_time: Optional[str] = None
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.failure_reason = None self.failure_reason = None
self.settings = settings or { self.settings = settings or {
@ -337,7 +340,7 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
self._output_encryption_kms_key_id = output_encryption_kms_key_id self._output_encryption_kms_key_id = output_encryption_kms_key_id
self.output_location_type = "CUSTOMER_BUCKET" self.output_location_type = "CUSTOMER_BUCKET"
def response_object(self, response_type): def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
response_field_dict = { response_field_dict = {
"CREATE": [ "CREATE": [
"MedicalTranscriptionJobName", "MedicalTranscriptionJobName",
@ -396,7 +399,7 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
if k in response_fields and v is not None and v != [None] if k in response_fields and v is not None and v != [None]
} }
def advance(self): def advance(self) -> None:
old_status = self.status old_status = self.status
super().advance() super().advance()
new_status = self.status new_status = self.status
@ -425,11 +428,11 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
class FakeMedicalVocabulary(FakeVocabulary): class FakeMedicalVocabulary(FakeVocabulary):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region_name, region_name: str,
vocabulary_name, vocabulary_name: str,
language_code, language_code: str,
vocabulary_file_uri, vocabulary_file_uri: Optional[str],
): ):
super().__init__( super().__init__(
account_id, account_id,
@ -450,12 +453,12 @@ class FakeMedicalVocabulary(FakeVocabulary):
class TranscribeBackend(BaseBackend): class TranscribeBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.medical_transcriptions = {} self.medical_transcriptions: Dict[str, FakeMedicalTranscriptionJob] = {}
self.transcriptions = {} self.transcriptions: Dict[str, FakeTranscriptionJob] = {}
self.medical_vocabularies = {} self.medical_vocabularies: Dict[str, FakeMedicalVocabulary] = {}
self.vocabularies = {} self.vocabularies: Dict[str, FakeVocabulary] = {}
state_manager.register_default_transition( state_manager.register_default_transition(
"transcribe::vocabulary", transition={"progression": "manual", "times": 1} "transcribe::vocabulary", transition={"progression": "manual", "times": 1}
@ -474,7 +477,9 @@ class TranscribeBackend(BaseBackend):
) )
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint services.""" """Default VPC endpoint services."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "transcribe" service_region, zones, "transcribe"
@ -482,15 +487,29 @@ class TranscribeBackend(BaseBackend):
service_region, zones, "transcribestreaming" service_region, zones, "transcribestreaming"
) )
def start_transcription_job(self, **kwargs): def start_transcription_job(
self,
name = kwargs.get("transcription_job_name") transcription_job_name: str,
if name in self.transcriptions: language_code: Optional[str],
media_sample_rate_hertz: Optional[int],
media_format: Optional[str],
media: Dict[str, str],
output_bucket_name: Optional[str],
output_key: Optional[str],
output_encryption_kms_key_id: Optional[str],
settings: Optional[Dict[str, Any]],
model_settings: Optional[Dict[str, Optional[str]]],
job_execution_settings: Optional[Dict[str, Any]],
content_redaction: Optional[Dict[str, Any]],
identify_language: Optional[bool],
identify_multiple_languages: Optional[bool],
language_options: Optional[List[str]],
) -> Dict[str, Any]:
if transcription_job_name in self.transcriptions:
raise ConflictException( raise ConflictException(
message="The requested job name already exists. Use a different job name." message="The requested job name already exists. Use a different job name."
) )
settings = kwargs.get("settings")
vocabulary_name = settings.get("VocabularyName") if settings else None vocabulary_name = settings.get("VocabularyName") if settings else None
if vocabulary_name and vocabulary_name not in self.vocabularies: if vocabulary_name and vocabulary_name not in self.vocabularies:
raise BadRequestException( raise BadRequestException(
@ -501,36 +520,45 @@ class TranscribeBackend(BaseBackend):
transcription_job_object = FakeTranscriptionJob( transcription_job_object = FakeTranscriptionJob(
account_id=self.account_id, account_id=self.account_id,
region_name=self.region_name, region_name=self.region_name,
transcription_job_name=name, transcription_job_name=transcription_job_name,
language_code=kwargs.get("language_code"), language_code=language_code,
media_sample_rate_hertz=kwargs.get("media_sample_rate_hertz"), media_sample_rate_hertz=media_sample_rate_hertz,
media_format=kwargs.get("media_format"), media_format=media_format,
media=kwargs.get("media"), media=media,
output_bucket_name=kwargs.get("output_bucket_name"), output_bucket_name=output_bucket_name,
output_key=kwargs.get("output_key"), output_key=output_key,
output_encryption_kms_key_id=kwargs.get("output_encryption_kms_key_id"), output_encryption_kms_key_id=output_encryption_kms_key_id,
settings=settings, settings=settings,
model_settings=kwargs.get("model_settings"), model_settings=model_settings,
job_execution_settings=kwargs.get("job_execution_settings"), job_execution_settings=job_execution_settings,
content_redaction=kwargs.get("content_redaction"), content_redaction=content_redaction,
identify_language=kwargs.get("identify_language"), identify_language=identify_language,
identify_multiple_languages=kwargs.get("identify_multiple_languages"), identify_multiple_languages=identify_multiple_languages,
language_options=kwargs.get("language_options"), language_options=language_options,
) )
self.transcriptions[name] = transcription_job_object self.transcriptions[transcription_job_name] = transcription_job_object
return transcription_job_object.response_object("CREATE") return transcription_job_object.response_object("CREATE")
def start_medical_transcription_job(self, **kwargs): def start_medical_transcription_job(
self,
medical_transcription_job_name: str,
language_code: str,
media_sample_rate_hertz: Optional[int],
media_format: Optional[str],
media: Dict[str, str],
output_bucket_name: str,
output_encryption_kms_key_id: Optional[str],
settings: Optional[Dict[str, Any]],
specialty: str,
type_: str,
) -> Dict[str, Any]:
name = kwargs.get("medical_transcription_job_name") if medical_transcription_job_name in self.medical_transcriptions:
if name in self.medical_transcriptions:
raise ConflictException( raise ConflictException(
message="The requested job name already exists. Use a different job name." message="The requested job name already exists. Use a different job name."
) )
settings = kwargs.get("settings")
vocabulary_name = settings.get("VocabularyName") if settings else None vocabulary_name = settings.get("VocabularyName") if settings else None
if vocabulary_name and vocabulary_name not in self.medical_vocabularies: if vocabulary_name and vocabulary_name not in self.medical_vocabularies:
raise BadRequestException( raise BadRequestException(
@ -540,23 +568,25 @@ class TranscribeBackend(BaseBackend):
transcription_job_object = FakeMedicalTranscriptionJob( transcription_job_object = FakeMedicalTranscriptionJob(
region_name=self.region_name, region_name=self.region_name,
medical_transcription_job_name=name, medical_transcription_job_name=medical_transcription_job_name,
language_code=kwargs.get("language_code"), language_code=language_code,
media_sample_rate_hertz=kwargs.get("media_sample_rate_hertz"), media_sample_rate_hertz=media_sample_rate_hertz,
media_format=kwargs.get("media_format"), media_format=media_format,
media=kwargs.get("media"), media=media,
output_bucket_name=kwargs.get("output_bucket_name"), output_bucket_name=output_bucket_name,
output_encryption_kms_key_id=kwargs.get("output_encryption_kms_key_id"), output_encryption_kms_key_id=output_encryption_kms_key_id,
settings=settings, settings=settings,
specialty=kwargs.get("specialty"), specialty=specialty,
job_type=kwargs.get("type"), job_type=type_,
) )
self.medical_transcriptions[name] = transcription_job_object self.medical_transcriptions[
medical_transcription_job_name
] = transcription_job_object
return transcription_job_object.response_object("CREATE") return transcription_job_object.response_object("CREATE")
def get_transcription_job(self, transcription_job_name): def get_transcription_job(self, transcription_job_name: str) -> Dict[str, Any]:
try: try:
job = self.transcriptions[transcription_job_name] job = self.transcriptions[transcription_job_name]
job.advance() # Fakes advancement through statuses. job.advance() # Fakes advancement through statuses.
@ -567,7 +597,9 @@ class TranscribeBackend(BaseBackend):
"Check the job name and try your request again." "Check the job name and try your request again."
) )
def get_medical_transcription_job(self, medical_transcription_job_name): def get_medical_transcription_job(
self, medical_transcription_job_name: str
) -> Dict[str, Any]:
try: try:
job = self.medical_transcriptions[medical_transcription_job_name] job = self.medical_transcriptions[medical_transcription_job_name]
job.advance() # Fakes advancement through statuses. job.advance() # Fakes advancement through statuses.
@ -578,7 +610,7 @@ class TranscribeBackend(BaseBackend):
"Check the job name and try your request again." "Check the job name and try your request again."
) )
def delete_transcription_job(self, transcription_job_name): def delete_transcription_job(self, transcription_job_name: str) -> None:
try: try:
del self.transcriptions[transcription_job_name] del self.transcriptions[transcription_job_name]
except KeyError: except KeyError:
@ -587,7 +619,9 @@ class TranscribeBackend(BaseBackend):
"Check the job name and try your request again." "Check the job name and try your request again."
) )
def delete_medical_transcription_job(self, medical_transcription_job_name): def delete_medical_transcription_job(
self, medical_transcription_job_name: str
) -> None:
try: try:
del self.medical_transcriptions[medical_transcription_job_name] del self.medical_transcriptions[medical_transcription_job_name]
except KeyError: except KeyError:
@ -597,8 +631,12 @@ class TranscribeBackend(BaseBackend):
) )
def list_transcription_jobs( def list_transcription_jobs(
self, state_equals, job_name_contains, next_token, max_results self,
): state_equals: str,
job_name_contains: str,
next_token: str,
max_results: int,
) -> Dict[str, Any]:
jobs = list(self.transcriptions.values()) jobs = list(self.transcriptions.values())
if state_equals: if state_equals:
@ -615,7 +653,7 @@ class TranscribeBackend(BaseBackend):
) # Arbitrarily selected... ) # Arbitrarily selected...
jobs_paginated = jobs[start_offset:end_offset] jobs_paginated = jobs[start_offset:end_offset]
response = { response: Dict[str, Any] = {
"TranscriptionJobSummaries": [ "TranscriptionJobSummaries": [
job.response_object("LIST") for job in jobs_paginated job.response_object("LIST") for job in jobs_paginated
] ]
@ -627,8 +665,8 @@ class TranscribeBackend(BaseBackend):
return response return response
def list_medical_transcription_jobs( def list_medical_transcription_jobs(
self, status, job_name_contains, next_token, max_results self, status: str, job_name_contains: str, next_token: str, max_results: int
): ) -> Dict[str, Any]:
jobs = list(self.medical_transcriptions.values()) jobs = list(self.medical_transcriptions.values())
if status: if status:
@ -647,7 +685,7 @@ class TranscribeBackend(BaseBackend):
) # Arbitrarily selected... ) # Arbitrarily selected...
jobs_paginated = jobs[start_offset:end_offset] jobs_paginated = jobs[start_offset:end_offset]
response = { response: Dict[str, Any] = {
"MedicalTranscriptionJobSummaries": [ "MedicalTranscriptionJobSummaries": [
job.response_object("LIST") for job in jobs_paginated job.response_object("LIST") for job in jobs_paginated
] ]
@ -658,12 +696,13 @@ class TranscribeBackend(BaseBackend):
response["Status"] = status response["Status"] = status
return response return response
def create_vocabulary(self, **kwargs): def create_vocabulary(
self,
vocabulary_name = kwargs.get("vocabulary_name") vocabulary_name: str,
language_code = kwargs.get("language_code") language_code: str,
phrases = kwargs.get("phrases") phrases: Optional[List[str]],
vocabulary_file_uri = kwargs.get("vocabulary_file_uri") vocabulary_file_uri: Optional[str],
) -> Dict[str, Any]:
if ( if (
phrases is not None phrases is not None
and vocabulary_file_uri is not None and vocabulary_file_uri is not None
@ -698,12 +737,12 @@ class TranscribeBackend(BaseBackend):
return vocabulary_object.response_object("CREATE") return vocabulary_object.response_object("CREATE")
def create_medical_vocabulary(self, **kwargs): def create_medical_vocabulary(
self,
vocabulary_name = kwargs.get("vocabulary_name") vocabulary_name: str,
language_code = kwargs.get("language_code") language_code: str,
vocabulary_file_uri = kwargs.get("vocabulary_file_uri") vocabulary_file_uri: Optional[str],
) -> Dict[str, Any]:
if vocabulary_name in self.medical_vocabularies: if vocabulary_name in self.medical_vocabularies:
raise ConflictException( raise ConflictException(
message="The requested vocabulary name already exists. " message="The requested vocabulary name already exists. "
@ -722,7 +761,7 @@ class TranscribeBackend(BaseBackend):
return medical_vocabulary_object.response_object("CREATE") return medical_vocabulary_object.response_object("CREATE")
def get_vocabulary(self, vocabulary_name): def get_vocabulary(self, vocabulary_name: str) -> Dict[str, Any]:
try: try:
job = self.vocabularies[vocabulary_name] job = self.vocabularies[vocabulary_name]
job.advance() # Fakes advancement through statuses. job.advance() # Fakes advancement through statuses.
@ -733,7 +772,7 @@ class TranscribeBackend(BaseBackend):
"Check the vocabulary name and try your request again." "Check the vocabulary name and try your request again."
) )
def get_medical_vocabulary(self, vocabulary_name): def get_medical_vocabulary(self, vocabulary_name: str) -> Dict[str, Any]:
try: try:
job = self.medical_vocabularies[vocabulary_name] job = self.medical_vocabularies[vocabulary_name]
job.advance() # Fakes advancement through statuses. job.advance() # Fakes advancement through statuses.
@ -744,7 +783,7 @@ class TranscribeBackend(BaseBackend):
"Check the vocabulary name and try your request again." "Check the vocabulary name and try your request again."
) )
def delete_vocabulary(self, vocabulary_name): def delete_vocabulary(self, vocabulary_name: str) -> None:
try: try:
del self.vocabularies[vocabulary_name] del self.vocabularies[vocabulary_name]
except KeyError: except KeyError:
@ -752,7 +791,7 @@ class TranscribeBackend(BaseBackend):
message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again." message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again."
) )
def delete_medical_vocabulary(self, vocabulary_name): def delete_medical_vocabulary(self, vocabulary_name: str) -> None:
try: try:
del self.medical_vocabularies[vocabulary_name] del self.medical_vocabularies[vocabulary_name]
except KeyError: except KeyError:
@ -760,7 +799,9 @@ class TranscribeBackend(BaseBackend):
message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again." message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again."
) )
def list_vocabularies(self, state_equals, name_contains, next_token, max_results): def list_vocabularies(
self, state_equals: str, name_contains: str, next_token: str, max_results: int
) -> Dict[str, Any]:
vocabularies = list(self.vocabularies.values()) vocabularies = list(self.vocabularies.values())
if state_equals: if state_equals:
@ -783,7 +824,7 @@ class TranscribeBackend(BaseBackend):
) # Arbitrarily selected... ) # Arbitrarily selected...
vocabularies_paginated = vocabularies[start_offset:end_offset] vocabularies_paginated = vocabularies[start_offset:end_offset]
response = { response: Dict[str, Any] = {
"Vocabularies": [ "Vocabularies": [
vocabulary.response_object("LIST") vocabulary.response_object("LIST")
for vocabulary in vocabularies_paginated for vocabulary in vocabularies_paginated
@ -796,8 +837,8 @@ class TranscribeBackend(BaseBackend):
return response return response
def list_medical_vocabularies( def list_medical_vocabularies(
self, state_equals, name_contains, next_token, max_results self, state_equals: str, name_contains: str, next_token: str, max_results: int
): ) -> Dict[str, Any]:
vocabularies = list(self.medical_vocabularies.values()) vocabularies = list(self.medical_vocabularies.values())
if state_equals: if state_equals:
@ -820,7 +861,7 @@ class TranscribeBackend(BaseBackend):
) # Arbitrarily selected... ) # Arbitrarily selected...
vocabularies_paginated = vocabularies[start_offset:end_offset] vocabularies_paginated = vocabularies[start_offset:end_offset]
response = { response: Dict[str, Any] = {
"Vocabularies": [ "Vocabularies": [
vocabulary.response_object("LIST") vocabulary.response_object("LIST")
for vocabulary in vocabularies_paginated for vocabulary in vocabularies_paginated

View File

@ -2,26 +2,19 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.utilities.aws_headers import amzn_request_id from moto.utilities.aws_headers import amzn_request_id
from .models import transcribe_backends from .models import transcribe_backends, TranscribeBackend
class TranscribeResponse(BaseResponse): class TranscribeResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="transcribe") super().__init__(service_name="transcribe")
@property @property
def transcribe_backend(self): def transcribe_backend(self) -> TranscribeBackend:
return transcribe_backends[self.current_account][self.region] return transcribe_backends[self.current_account][self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
@amzn_request_id @amzn_request_id
def start_transcription_job(self): def start_transcription_job(self) -> str:
name = self._get_param("TranscriptionJobName") name = self._get_param("TranscriptionJobName")
response = self.transcribe_backend.start_transcription_job( response = self.transcribe_backend.start_transcription_job(
transcription_job_name=name, transcription_job_name=name,
@ -43,7 +36,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def start_medical_transcription_job(self): def start_medical_transcription_job(self) -> str:
name = self._get_param("MedicalTranscriptionJobName") name = self._get_param("MedicalTranscriptionJobName")
response = self.transcribe_backend.start_medical_transcription_job( response = self.transcribe_backend.start_medical_transcription_job(
medical_transcription_job_name=name, medical_transcription_job_name=name,
@ -55,12 +48,12 @@ class TranscribeResponse(BaseResponse):
output_encryption_kms_key_id=self._get_param("OutputEncryptionKMSKeyId"), output_encryption_kms_key_id=self._get_param("OutputEncryptionKMSKeyId"),
settings=self._get_param("Settings"), settings=self._get_param("Settings"),
specialty=self._get_param("Specialty"), specialty=self._get_param("Specialty"),
type=self._get_param("Type"), type_=self._get_param("Type"),
) )
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def list_transcription_jobs(self): def list_transcription_jobs(self) -> str:
state_equals = self._get_param("Status") state_equals = self._get_param("Status")
job_name_contains = self._get_param("JobNameContains") job_name_contains = self._get_param("JobNameContains")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -75,7 +68,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def list_medical_transcription_jobs(self): def list_medical_transcription_jobs(self) -> str:
status = self._get_param("Status") status = self._get_param("Status")
job_name_contains = self._get_param("JobNameContains") job_name_contains = self._get_param("JobNameContains")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -90,7 +83,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def get_transcription_job(self): def get_transcription_job(self) -> str:
transcription_job_name = self._get_param("TranscriptionJobName") transcription_job_name = self._get_param("TranscriptionJobName")
response = self.transcribe_backend.get_transcription_job( response = self.transcribe_backend.get_transcription_job(
transcription_job_name=transcription_job_name transcription_job_name=transcription_job_name
@ -98,7 +91,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def get_medical_transcription_job(self): def get_medical_transcription_job(self) -> str:
medical_transcription_job_name = self._get_param("MedicalTranscriptionJobName") medical_transcription_job_name = self._get_param("MedicalTranscriptionJobName")
response = self.transcribe_backend.get_medical_transcription_job( response = self.transcribe_backend.get_medical_transcription_job(
medical_transcription_job_name=medical_transcription_job_name medical_transcription_job_name=medical_transcription_job_name
@ -106,23 +99,23 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def delete_transcription_job(self): def delete_transcription_job(self) -> str:
transcription_job_name = self._get_param("TranscriptionJobName") transcription_job_name = self._get_param("TranscriptionJobName")
response = self.transcribe_backend.delete_transcription_job( self.transcribe_backend.delete_transcription_job(
transcription_job_name=transcription_job_name transcription_job_name=transcription_job_name
) )
return json.dumps(response) return "{}"
@amzn_request_id @amzn_request_id
def delete_medical_transcription_job(self): def delete_medical_transcription_job(self) -> str:
medical_transcription_job_name = self._get_param("MedicalTranscriptionJobName") medical_transcription_job_name = self._get_param("MedicalTranscriptionJobName")
response = self.transcribe_backend.delete_medical_transcription_job( self.transcribe_backend.delete_medical_transcription_job(
medical_transcription_job_name=medical_transcription_job_name medical_transcription_job_name=medical_transcription_job_name
) )
return json.dumps(response) return "{}"
@amzn_request_id @amzn_request_id
def create_vocabulary(self): def create_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName") vocabulary_name = self._get_param("VocabularyName")
language_code = self._get_param("LanguageCode") language_code = self._get_param("LanguageCode")
phrases = self._get_param("Phrases") phrases = self._get_param("Phrases")
@ -136,7 +129,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def create_medical_vocabulary(self): def create_medical_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName") vocabulary_name = self._get_param("VocabularyName")
language_code = self._get_param("LanguageCode") language_code = self._get_param("LanguageCode")
vocabulary_file_uri = self._get_param("VocabularyFileUri") vocabulary_file_uri = self._get_param("VocabularyFileUri")
@ -148,7 +141,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def get_vocabulary(self): def get_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName") vocabulary_name = self._get_param("VocabularyName")
response = self.transcribe_backend.get_vocabulary( response = self.transcribe_backend.get_vocabulary(
vocabulary_name=vocabulary_name vocabulary_name=vocabulary_name
@ -156,7 +149,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def get_medical_vocabulary(self): def get_medical_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName") vocabulary_name = self._get_param("VocabularyName")
response = self.transcribe_backend.get_medical_vocabulary( response = self.transcribe_backend.get_medical_vocabulary(
vocabulary_name=vocabulary_name vocabulary_name=vocabulary_name
@ -164,7 +157,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def list_vocabularies(self): def list_vocabularies(self) -> str:
state_equals = self._get_param("StateEquals") state_equals = self._get_param("StateEquals")
name_contains = self._get_param("NameContains") name_contains = self._get_param("NameContains")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -179,7 +172,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def list_medical_vocabularies(self): def list_medical_vocabularies(self) -> str:
state_equals = self._get_param("StateEquals") state_equals = self._get_param("StateEquals")
name_contains = self._get_param("NameContains") name_contains = self._get_param("NameContains")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -194,17 +187,15 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
@amzn_request_id @amzn_request_id
def delete_vocabulary(self): def delete_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName") vocabulary_name = self._get_param("VocabularyName")
response = self.transcribe_backend.delete_vocabulary( self.transcribe_backend.delete_vocabulary(vocabulary_name=vocabulary_name)
vocabulary_name=vocabulary_name return "{}"
)
return json.dumps(response)
@amzn_request_id @amzn_request_id
def delete_medical_vocabulary(self): def delete_medical_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName") vocabulary_name = self._get_param("VocabularyName")
response = self.transcribe_backend.delete_medical_vocabulary( self.transcribe_backend.delete_medical_vocabulary(
vocabulary_name=vocabulary_name vocabulary_name=vocabulary_name
) )
return json.dumps(response) return "{}"

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u* files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u*,moto/t*
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract