Techdebt: MyPy T (#6270)

This commit is contained in:
Bert Blommers 2023-04-29 22:21:00 +00:00 committed by GitHub
parent e5e1521523
commit 7d6afe4b67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 315 additions and 266 deletions

View File

@ -5,16 +5,16 @@ from moto.core.exceptions import JsonRESTError
class InvalidJobIdException(JsonRESTError):
code = 400
def __init__(self):
super().__init__(__class__.__name__, "An invalid job identifier was passed.")
def __init__(self) -> None:
super().__init__(__class__.__name__, "An invalid job identifier was passed.") # type: ignore
class InvalidS3ObjectException(JsonRESTError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
__class__.__name__,
__class__.__name__, # type: ignore
"Amazon Textract is unable to access the S3 object that's specified in the request.",
)
@ -22,8 +22,8 @@ class InvalidS3ObjectException(JsonRESTError):
class InvalidParameterException(JsonRESTError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
__class__.__name__,
__class__.__name__, # type: ignore
"An input parameter violated a constraint. For example, in synchronous operations, an InvalidParameterException exception occurs when neither of the S3Object or Bytes values are supplied in the Document request parameter. Validate your parameter before calling the API operation again.",
)

View File

@ -1,6 +1,5 @@
"""TextractBackend class with methods for supported APIs."""
from collections import defaultdict
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random
@ -16,10 +15,10 @@ class TextractJobStatus:
class TextractJob(BaseModel):
def __init__(self, job):
def __init__(self, job: Dict[str, Any]):
self.job = job
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return self.job
@ -28,13 +27,13 @@ class TextractBackend(BaseBackend):
JOB_STATUS = TextractJobStatus.succeeded
PAGES = {"Pages": mock_random.randint(5, 500)}
BLOCKS = []
BLOCKS: List[Dict[str, Any]] = []
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.async_text_detection_jobs = defaultdict()
self.async_text_detection_jobs: Dict[str, TextractJob] = defaultdict()
def get_document_text_detection(self, job_id):
def get_document_text_detection(self, job_id: str) -> TextractJob:
"""
Pagination has not yet been implemented
"""
@ -43,7 +42,7 @@ class TextractBackend(BaseBackend):
raise InvalidJobIdException()
return job
def start_document_text_detection(self, document_location):
def start_document_text_detection(self, document_location: str) -> str:
"""
The following parameters have not yet been implemented: ClientRequestToken, JobTag, NotificationChannel, OutputConfig, KmsKeyID
"""

View File

@ -2,27 +2,27 @@
import json
from moto.core.responses import BaseResponse
from .models import textract_backends
from .models import textract_backends, TextractBackend
class TextractResponse(BaseResponse):
"""Handler for Textract requests and responses."""
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="textract")
@property
def textract_backend(self):
def textract_backend(self) -> TextractBackend:
"""Return backend instance specific for this region."""
return textract_backends[self.current_account][self.region]
def get_document_text_detection(self):
def get_document_text_detection(self) -> str:
params = json.loads(self.body)
job_id = params.get("JobId")
job = self.textract_backend.get_document_text_detection(job_id=job_id).to_dict()
return json.dumps(job)
def start_document_text_detection(self):
def start_document_text_detection(self) -> str:
params = json.loads(self.body)
document_location = params.get("DocumentLocation")
job_id = self.textract_backend.start_document_text_detection(

View File

@ -5,5 +5,5 @@ from moto.core.exceptions import JsonRESTError
class ResourceNotFound(JsonRESTError):
error_type = "com.amazonaws.timestream.v20181101#ResourceNotFoundException"
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__(ResourceNotFound.error_type, msg)

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Iterable
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.utilities.tagging_service import TaggingService
from .exceptions import ResourceNotFound
@ -6,12 +7,12 @@ from .exceptions import ResourceNotFound
class TimestreamTable(BaseModel):
def __init__(
self,
account_id,
region_name,
table_name,
db_name,
retention_properties,
magnetic_store_write_properties,
account_id: str,
region_name: str,
table_name: str,
db_name: str,
retention_properties: Dict[str, int],
magnetic_store_write_properties: Dict[str, Any],
):
self.region_name = region_name
self.name = table_name
@ -21,18 +22,22 @@ class TimestreamTable(BaseModel):
"MagneticStoreRetentionPeriodInDays": 123,
}
self.magnetic_store_write_properties = magnetic_store_write_properties or {}
self.records = []
self.records: List[Dict[str, Any]] = []
self.arn = f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.db_name}/table/{self.name}"
def update(self, retention_properties, magnetic_store_write_properties):
def update(
self,
retention_properties: Dict[str, int],
magnetic_store_write_properties: Dict[str, Any],
) -> None:
self.retention_properties = retention_properties
if magnetic_store_write_properties is not None:
self.magnetic_store_write_properties = magnetic_store_write_properties
def write_records(self, records):
def write_records(self, records: List[Dict[str, Any]]) -> None:
self.records.extend(records)
def description(self):
def description(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"TableName": self.name,
@ -44,7 +49,9 @@ class TimestreamTable(BaseModel):
class TimestreamDatabase(BaseModel):
def __init__(self, account_id, region_name, database_name, kms_key_id):
def __init__(
self, account_id: str, region_name: str, database_name: str, kms_key_id: str
):
self.account_id = account_id
self.region_name = region_name
self.name = database_name
@ -54,14 +61,17 @@ class TimestreamDatabase(BaseModel):
self.arn = (
f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.name}"
)
self.tables = dict()
self.tables: Dict[str, TimestreamTable] = dict()
def update(self, kms_key_id):
def update(self, kms_key_id: str) -> None:
self.kms_key_id = kms_key_id
def create_table(
self, table_name, retention_properties, magnetic_store_write_properties
):
self,
table_name: str,
retention_properties: Dict[str, int],
magnetic_store_write_properties: Dict[str, Any],
) -> TimestreamTable:
table = TimestreamTable(
account_id=self.account_id,
region_name=self.region_name,
@ -74,8 +84,11 @@ class TimestreamDatabase(BaseModel):
return table
def update_table(
self, table_name, retention_properties, magnetic_store_write_properties
):
self,
table_name: str,
retention_properties: Dict[str, int],
magnetic_store_write_properties: Dict[str, Any],
) -> TimestreamTable:
table = self.tables[table_name]
table.update(
retention_properties=retention_properties,
@ -83,18 +96,18 @@ class TimestreamDatabase(BaseModel):
)
return table
def delete_table(self, table_name):
def delete_table(self, table_name: str) -> None:
self.tables.pop(table_name, None)
def describe_table(self, table_name):
def describe_table(self, table_name: str) -> TimestreamTable:
if table_name not in self.tables:
raise ResourceNotFound(f"The table {table_name} does not exist.")
return self.tables[table_name]
def list_tables(self):
def list_tables(self) -> Iterable[TimestreamTable]:
return self.tables.values()
def description(self):
def description(self) -> Dict[str, Any]:
return {
"Arn": self.arn,
"DatabaseName": self.name,
@ -117,12 +130,14 @@ class TimestreamWriteBackend(BaseBackend):
"""
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.databases = dict()
self.databases: Dict[str, TimestreamDatabase] = dict()
self.tagging_service = TaggingService()
def create_database(self, database_name, kms_key_id, tags):
def create_database(
self, database_name: str, kms_key_id: str, tags: List[Dict[str, str]]
) -> TimestreamDatabase:
database = TimestreamDatabase(
self.account_id, self.region_name, database_name, kms_key_id
)
@ -130,30 +145,32 @@ class TimestreamWriteBackend(BaseBackend):
self.tagging_service.tag_resource(database.arn, tags)
return database
def delete_database(self, database_name):
def delete_database(self, database_name: str) -> None:
del self.databases[database_name]
def describe_database(self, database_name):
def describe_database(self, database_name: str) -> TimestreamDatabase:
if database_name not in self.databases:
raise ResourceNotFound(f"The database {database_name} does not exist.")
return self.databases[database_name]
def list_databases(self):
def list_databases(self) -> Iterable[TimestreamDatabase]:
return self.databases.values()
def update_database(self, database_name, kms_key_id):
def update_database(
self, database_name: str, kms_key_id: str
) -> TimestreamDatabase:
database = self.databases[database_name]
database.update(kms_key_id=kms_key_id)
return database
def create_table(
self,
database_name,
table_name,
retention_properties,
tags,
magnetic_store_write_properties,
):
database_name: str,
table_name: str,
retention_properties: Dict[str, int],
tags: List[Dict[str, str]],
magnetic_store_write_properties: Dict[str, Any],
) -> TimestreamTable:
database = self.describe_database(database_name)
table = database.create_table(
table_name, retention_properties, magnetic_store_write_properties
@ -161,39 +178,38 @@ class TimestreamWriteBackend(BaseBackend):
self.tagging_service.tag_resource(table.arn, tags)
return table
def delete_table(self, database_name, table_name):
def delete_table(self, database_name: str, table_name: str) -> None:
database = self.describe_database(database_name)
database.delete_table(table_name)
def describe_table(self, database_name, table_name):
def describe_table(self, database_name: str, table_name: str) -> TimestreamTable:
database = self.describe_database(database_name)
table = database.describe_table(table_name)
return table
return database.describe_table(table_name)
def list_tables(self, database_name):
def list_tables(self, database_name: str) -> Iterable[TimestreamTable]:
database = self.describe_database(database_name)
tables = database.list_tables()
return tables
return database.list_tables()
def update_table(
self,
database_name,
table_name,
retention_properties,
magnetic_store_write_properties,
):
database_name: str,
table_name: str,
retention_properties: Dict[str, int],
magnetic_store_write_properties: Dict[str, Any],
) -> TimestreamTable:
database = self.describe_database(database_name)
table = database.update_table(
return database.update_table(
table_name, retention_properties, magnetic_store_write_properties
)
return table
def write_records(self, database_name, table_name, records):
def write_records(
self, database_name: str, table_name: str, records: List[Dict[str, Any]]
) -> None:
database = self.describe_database(database_name)
table = database.describe_table(table_name)
table.write_records(records)
def describe_endpoints(self):
def describe_endpoints(self) -> Dict[str, List[Dict[str, Any]]]:
# https://docs.aws.amazon.com/timestream/latest/developerguide/Using-API.endpoint-discovery.how-it-works.html
# Usually, the address look like this:
# ingest-cell1.timestream.us-east-1.amazonaws.com
@ -208,13 +224,15 @@ class TimestreamWriteBackend(BaseBackend):
]
}
def list_tags_for_resource(self, resource_arn):
def list_tags_for_resource(
self, resource_arn: str
) -> Dict[str, List[Dict[str, str]]]:
return self.tagging_service.list_tags_for_resource(resource_arn)
def tag_resource(self, resource_arn, tags):
def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
self.tagging_service.tag_resource(resource_arn, tags)
def untag_resource(self, resource_arn, tag_keys):
def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagging_service.untag_resource_using_names(resource_arn, tag_keys)

View File

@ -1,19 +1,19 @@
import json
from moto.core.responses import BaseResponse
from .models import timestreamwrite_backends
from .models import timestreamwrite_backends, TimestreamWriteBackend
class TimestreamWriteResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="timestream-write")
@property
def timestreamwrite_backend(self):
def timestreamwrite_backend(self) -> TimestreamWriteBackend:
"""Return backend instance specific for this region."""
return timestreamwrite_backends[self.current_account][self.region]
def create_database(self):
def create_database(self) -> str:
database_name = self._get_param("DatabaseName")
kms_key_id = self._get_param("KmsKeyId")
tags = self._get_param("Tags")
@ -22,19 +22,19 @@ class TimestreamWriteResponse(BaseResponse):
)
return json.dumps(dict(Database=database.description()))
def delete_database(self):
def delete_database(self) -> str:
database_name = self._get_param("DatabaseName")
self.timestreamwrite_backend.delete_database(database_name=database_name)
return "{}"
def describe_database(self):
def describe_database(self) -> str:
database_name = self._get_param("DatabaseName")
database = self.timestreamwrite_backend.describe_database(
database_name=database_name
)
return json.dumps(dict(Database=database.description()))
def update_database(self):
def update_database(self) -> str:
database_name = self._get_param("DatabaseName")
kms_key_id = self._get_param("KmsKeyId")
database = self.timestreamwrite_backend.update_database(
@ -42,11 +42,11 @@ class TimestreamWriteResponse(BaseResponse):
)
return json.dumps(dict(Database=database.description()))
def list_databases(self):
def list_databases(self) -> str:
all_dbs = self.timestreamwrite_backend.list_databases()
return json.dumps(dict(Databases=[db.description() for db in all_dbs]))
def create_table(self):
def create_table(self) -> str:
database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName")
retention_properties = self._get_param("RetentionProperties")
@ -63,24 +63,24 @@ class TimestreamWriteResponse(BaseResponse):
)
return json.dumps(dict(Table=table.description()))
def delete_table(self):
def delete_table(self) -> str:
database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName")
self.timestreamwrite_backend.delete_table(database_name, table_name)
return "{}"
def describe_table(self):
def describe_table(self) -> str:
database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName")
table = self.timestreamwrite_backend.describe_table(database_name, table_name)
return json.dumps(dict(Table=table.description()))
def list_tables(self):
def list_tables(self) -> str:
database_name = self._get_param("DatabaseName")
tables = self.timestreamwrite_backend.list_tables(database_name)
return json.dumps(dict(Tables=[t.description() for t in tables]))
def update_table(self):
def update_table(self) -> str:
database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName")
retention_properties = self._get_param("RetentionProperties")
@ -95,7 +95,7 @@ class TimestreamWriteResponse(BaseResponse):
)
return json.dumps(dict(Table=table.description()))
def write_records(self):
def write_records(self) -> str:
database_name = self._get_param("DatabaseName")
table_name = self._get_param("TableName")
records = self._get_param("Records")
@ -109,22 +109,22 @@ class TimestreamWriteResponse(BaseResponse):
}
return json.dumps(resp)
def describe_endpoints(self):
def describe_endpoints(self) -> str:
resp = self.timestreamwrite_backend.describe_endpoints()
return json.dumps(resp)
def list_tags_for_resource(self):
def list_tags_for_resource(self) -> str:
resource_arn = self._get_param("ResourceARN")
tags = self.timestreamwrite_backend.list_tags_for_resource(resource_arn)
return json.dumps(tags)
def tag_resource(self):
def tag_resource(self) -> str:
resource_arn = self._get_param("ResourceARN")
tags = self._get_param("Tags")
self.timestreamwrite_backend.tag_resource(resource_arn, tags)
return "{}"
def untag_resource(self):
def untag_resource(self) -> str:
resource_arn = self._get_param("ResourceARN")
tag_keys = self._get_param("TagKeys")
self.timestreamwrite_backend.untag_resource(resource_arn, tag_keys)

View File

@ -2,10 +2,10 @@ from moto.core.exceptions import JsonRESTError
class ConflictException(JsonRESTError):
def __init__(self, message, **kwargs):
super().__init__("ConflictException", message, **kwargs)
def __init__(self, message: str):
super().__init__("ConflictException", message)
class BadRequestException(JsonRESTError):
def __init__(self, message, **kwargs):
super().__init__("BadRequestException", message, **kwargs)
def __init__(self, message: str):
super().__init__("BadRequestException", message)

View File

@ -1,4 +1,5 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api import state_manager
from moto.moto_api._internal import mock_random
@ -7,14 +8,14 @@ from .exceptions import ConflictException, BadRequestException
class BaseObject(BaseModel):
def camelCase(self, key):
def camelCase(self, key: str) -> str:
words = []
for word in key.split("_"):
words.append(word.title())
return "".join(words)
def gen_response_object(self):
response_object = dict()
def gen_response_object(self) -> Dict[str, Any]:
response_object: Dict[str, Any] = dict()
for key, value in self.__dict__.items():
if "_" in key:
response_object[self.camelCase(key)] = value
@ -23,30 +24,30 @@ class BaseObject(BaseModel):
return response_object
@property
def response_object(self):
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
return self.gen_response_object()
class FakeTranscriptionJob(BaseObject, ManagedState):
def __init__(
self,
account_id,
region_name,
transcription_job_name,
language_code,
media_sample_rate_hertz,
media_format,
media,
output_bucket_name,
output_key,
output_encryption_kms_key_id,
settings,
model_settings,
job_execution_settings,
content_redaction,
identify_language,
identify_multiple_languages,
language_options,
account_id: str,
region_name: str,
transcription_job_name: str,
language_code: Optional[str],
media_sample_rate_hertz: Optional[int],
media_format: Optional[str],
media: Dict[str, str],
output_bucket_name: Optional[str],
output_key: Optional[str],
output_encryption_kms_key_id: Optional[str],
settings: Optional[Dict[str, Any]],
model_settings: Optional[Dict[str, Optional[str]]],
job_execution_settings: Optional[Dict[str, Any]],
content_redaction: Optional[Dict[str, Any]],
identify_language: Optional[bool],
identify_multiple_languages: Optional[bool],
language_options: Optional[List[str]],
):
ManagedState.__init__(
self,
@ -61,12 +62,13 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self._region_name = region_name
self.transcription_job_name = transcription_job_name
self.language_code = language_code
self.language_codes = None
self.language_codes: Optional[List[Dict[str, Any]]] = None
self.media_sample_rate_hertz = media_sample_rate_hertz
self.media_format = media_format
self.media = media
self.transcript = None
self.start_time = self.completion_time = None
self.transcript: Optional[Dict[str, str]] = None
self.start_time: Optional[str] = None
self.completion_time: Optional[str] = None
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.failure_reason = None
self.settings = settings or {
@ -86,7 +88,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self.identify_language = identify_language
self.identify_multiple_languages = identify_multiple_languages
self.language_options = language_options
self.identified_language_score = (None,)
self.identified_language_score: Optional[float] = None
self._output_bucket_name = output_bucket_name
self.output_key = output_key
self._output_encryption_kms_key_id = output_encryption_kms_key_id
@ -94,7 +96,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
)
def response_object(self, response_type):
def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
response_field_dict = {
"CREATE": [
"TranscriptionJobName",
@ -162,7 +164,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
if k in response_fields and v is not None and v != [None]
}
def advance(self):
def advance(self) -> None:
old_status = self.status
super().advance()
new_status = self.status
@ -191,20 +193,20 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self.identified_language_score = 0.999645948
# Identify first two languages passed in language_options
# If none is set, default to "en-US"
self.language_codes = []
self.language_codes: List[Dict[str, Any]] = [] # type: ignore[no-redef]
if self.language_options is None or len(self.language_options) == 0:
self.language_codes.append(
self.language_codes.append( # type: ignore
{"LanguageCode": "en-US", "DurationInSeconds": 123.0}
)
else:
self.language_codes.append(
self.language_codes.append( # type: ignore
{
"LanguageCode": self.language_options[0],
"DurationInSeconds": 123.0,
}
)
if len(self.language_options) > 1:
self.language_codes.append(
self.language_codes.append( # type: ignore
{
"LanguageCode": self.language_options[1],
"DurationInSeconds": 321.0,
@ -229,12 +231,12 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
class FakeVocabulary(BaseObject, ManagedState):
def __init__(
self,
account_id,
region_name,
vocabulary_name,
language_code,
phrases,
vocabulary_file_uri,
account_id: str,
region_name: str,
vocabulary_name: str,
language_code: str,
phrases: Optional[List[str]],
vocabulary_file_uri: Optional[str],
):
# Configured ManagedState
super().__init__(
@ -247,11 +249,11 @@ class FakeVocabulary(BaseObject, ManagedState):
self.language_code = language_code
self.phrases = phrases
self.vocabulary_file_uri = vocabulary_file_uri
self.last_modified_time = None
self.last_modified_time: Optional[str] = None
self.failure_reason = None
self.download_uri = f"https://s3.{region_name}.amazonaws.com/aws-transcribe-dictionary-model-{region_name}-prod/{account_id}/{vocabulary_name}/{mock_random.uuid4()}/input.txt"
def response_object(self, response_type):
def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
response_field_dict = {
"CREATE": [
"VocabularyName",
@ -284,7 +286,7 @@ class FakeVocabulary(BaseObject, ManagedState):
if k in response_fields and v is not None and v != [None]
}
def advance(self):
def advance(self) -> None:
old_status = self.status
super().advance()
new_status = self.status
@ -296,17 +298,17 @@ class FakeVocabulary(BaseObject, ManagedState):
class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
def __init__(
self,
region_name,
medical_transcription_job_name,
language_code,
media_sample_rate_hertz,
media_format,
media,
output_bucket_name,
output_encryption_kms_key_id,
settings,
specialty,
job_type,
region_name: str,
medical_transcription_job_name: str,
language_code: str,
media_sample_rate_hertz: Optional[int],
media_format: Optional[str],
media: Dict[str, str],
output_bucket_name: str,
output_encryption_kms_key_id: Optional[str],
settings: Optional[Dict[str, Any]],
specialty: str,
job_type: str,
):
ManagedState.__init__(
self,
@ -323,8 +325,9 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
self.media_sample_rate_hertz = media_sample_rate_hertz
self.media_format = media_format
self.media = media
self.transcript = None
self.start_time = self.completion_time = None
self.transcript: Optional[Dict[str, str]] = None
self.start_time: Optional[str] = None
self.completion_time: Optional[str] = None
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.failure_reason = None
self.settings = settings or {
@ -337,7 +340,7 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
self._output_encryption_kms_key_id = output_encryption_kms_key_id
self.output_location_type = "CUSTOMER_BUCKET"
def response_object(self, response_type):
def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
response_field_dict = {
"CREATE": [
"MedicalTranscriptionJobName",
@ -396,7 +399,7 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
if k in response_fields and v is not None and v != [None]
}
def advance(self):
def advance(self) -> None:
old_status = self.status
super().advance()
new_status = self.status
@ -425,11 +428,11 @@ class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
class FakeMedicalVocabulary(FakeVocabulary):
def __init__(
self,
account_id,
region_name,
vocabulary_name,
language_code,
vocabulary_file_uri,
account_id: str,
region_name: str,
vocabulary_name: str,
language_code: str,
vocabulary_file_uri: Optional[str],
):
super().__init__(
account_id,
@ -450,12 +453,12 @@ class FakeMedicalVocabulary(FakeVocabulary):
class TranscribeBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.medical_transcriptions = {}
self.transcriptions = {}
self.medical_vocabularies = {}
self.vocabularies = {}
self.medical_transcriptions: Dict[str, FakeMedicalTranscriptionJob] = {}
self.transcriptions: Dict[str, FakeTranscriptionJob] = {}
self.medical_vocabularies: Dict[str, FakeMedicalVocabulary] = {}
self.vocabularies: Dict[str, FakeVocabulary] = {}
state_manager.register_default_transition(
"transcribe::vocabulary", transition={"progression": "manual", "times": 1}
@ -474,7 +477,9 @@ class TranscribeBackend(BaseBackend):
)
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint services."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "transcribe"
@ -482,15 +487,29 @@ class TranscribeBackend(BaseBackend):
service_region, zones, "transcribestreaming"
)
def start_transcription_job(self, **kwargs):
name = kwargs.get("transcription_job_name")
if name in self.transcriptions:
def start_transcription_job(
self,
transcription_job_name: str,
language_code: Optional[str],
media_sample_rate_hertz: Optional[int],
media_format: Optional[str],
media: Dict[str, str],
output_bucket_name: Optional[str],
output_key: Optional[str],
output_encryption_kms_key_id: Optional[str],
settings: Optional[Dict[str, Any]],
model_settings: Optional[Dict[str, Optional[str]]],
job_execution_settings: Optional[Dict[str, Any]],
content_redaction: Optional[Dict[str, Any]],
identify_language: Optional[bool],
identify_multiple_languages: Optional[bool],
language_options: Optional[List[str]],
) -> Dict[str, Any]:
if transcription_job_name in self.transcriptions:
raise ConflictException(
message="The requested job name already exists. Use a different job name."
)
settings = kwargs.get("settings")
vocabulary_name = settings.get("VocabularyName") if settings else None
if vocabulary_name and vocabulary_name not in self.vocabularies:
raise BadRequestException(
@ -501,36 +520,45 @@ class TranscribeBackend(BaseBackend):
transcription_job_object = FakeTranscriptionJob(
account_id=self.account_id,
region_name=self.region_name,
transcription_job_name=name,
language_code=kwargs.get("language_code"),
media_sample_rate_hertz=kwargs.get("media_sample_rate_hertz"),
media_format=kwargs.get("media_format"),
media=kwargs.get("media"),
output_bucket_name=kwargs.get("output_bucket_name"),
output_key=kwargs.get("output_key"),
output_encryption_kms_key_id=kwargs.get("output_encryption_kms_key_id"),
transcription_job_name=transcription_job_name,
language_code=language_code,
media_sample_rate_hertz=media_sample_rate_hertz,
media_format=media_format,
media=media,
output_bucket_name=output_bucket_name,
output_key=output_key,
output_encryption_kms_key_id=output_encryption_kms_key_id,
settings=settings,
model_settings=kwargs.get("model_settings"),
job_execution_settings=kwargs.get("job_execution_settings"),
content_redaction=kwargs.get("content_redaction"),
identify_language=kwargs.get("identify_language"),
identify_multiple_languages=kwargs.get("identify_multiple_languages"),
language_options=kwargs.get("language_options"),
model_settings=model_settings,
job_execution_settings=job_execution_settings,
content_redaction=content_redaction,
identify_language=identify_language,
identify_multiple_languages=identify_multiple_languages,
language_options=language_options,
)
self.transcriptions[name] = transcription_job_object
self.transcriptions[transcription_job_name] = transcription_job_object
return transcription_job_object.response_object("CREATE")
def start_medical_transcription_job(self, **kwargs):
def start_medical_transcription_job(
self,
medical_transcription_job_name: str,
language_code: str,
media_sample_rate_hertz: Optional[int],
media_format: Optional[str],
media: Dict[str, str],
output_bucket_name: str,
output_encryption_kms_key_id: Optional[str],
settings: Optional[Dict[str, Any]],
specialty: str,
type_: str,
) -> Dict[str, Any]:
name = kwargs.get("medical_transcription_job_name")
if name in self.medical_transcriptions:
if medical_transcription_job_name in self.medical_transcriptions:
raise ConflictException(
message="The requested job name already exists. Use a different job name."
)
settings = kwargs.get("settings")
vocabulary_name = settings.get("VocabularyName") if settings else None
if vocabulary_name and vocabulary_name not in self.medical_vocabularies:
raise BadRequestException(
@ -540,23 +568,25 @@ class TranscribeBackend(BaseBackend):
transcription_job_object = FakeMedicalTranscriptionJob(
region_name=self.region_name,
medical_transcription_job_name=name,
language_code=kwargs.get("language_code"),
media_sample_rate_hertz=kwargs.get("media_sample_rate_hertz"),
media_format=kwargs.get("media_format"),
media=kwargs.get("media"),
output_bucket_name=kwargs.get("output_bucket_name"),
output_encryption_kms_key_id=kwargs.get("output_encryption_kms_key_id"),
medical_transcription_job_name=medical_transcription_job_name,
language_code=language_code,
media_sample_rate_hertz=media_sample_rate_hertz,
media_format=media_format,
media=media,
output_bucket_name=output_bucket_name,
output_encryption_kms_key_id=output_encryption_kms_key_id,
settings=settings,
specialty=kwargs.get("specialty"),
job_type=kwargs.get("type"),
specialty=specialty,
job_type=type_,
)
self.medical_transcriptions[name] = transcription_job_object
self.medical_transcriptions[
medical_transcription_job_name
] = transcription_job_object
return transcription_job_object.response_object("CREATE")
def get_transcription_job(self, transcription_job_name):
def get_transcription_job(self, transcription_job_name: str) -> Dict[str, Any]:
try:
job = self.transcriptions[transcription_job_name]
job.advance() # Fakes advancement through statuses.
@ -567,7 +597,9 @@ class TranscribeBackend(BaseBackend):
"Check the job name and try your request again."
)
def get_medical_transcription_job(self, medical_transcription_job_name):
def get_medical_transcription_job(
self, medical_transcription_job_name: str
) -> Dict[str, Any]:
try:
job = self.medical_transcriptions[medical_transcription_job_name]
job.advance() # Fakes advancement through statuses.
@ -578,7 +610,7 @@ class TranscribeBackend(BaseBackend):
"Check the job name and try your request again."
)
def delete_transcription_job(self, transcription_job_name):
def delete_transcription_job(self, transcription_job_name: str) -> None:
try:
del self.transcriptions[transcription_job_name]
except KeyError:
@ -587,7 +619,9 @@ class TranscribeBackend(BaseBackend):
"Check the job name and try your request again."
)
def delete_medical_transcription_job(self, medical_transcription_job_name):
def delete_medical_transcription_job(
self, medical_transcription_job_name: str
) -> None:
try:
del self.medical_transcriptions[medical_transcription_job_name]
except KeyError:
@ -597,8 +631,12 @@ class TranscribeBackend(BaseBackend):
)
def list_transcription_jobs(
self, state_equals, job_name_contains, next_token, max_results
):
self,
state_equals: str,
job_name_contains: str,
next_token: str,
max_results: int,
) -> Dict[str, Any]:
jobs = list(self.transcriptions.values())
if state_equals:
@ -615,7 +653,7 @@ class TranscribeBackend(BaseBackend):
) # Arbitrarily selected...
jobs_paginated = jobs[start_offset:end_offset]
response = {
response: Dict[str, Any] = {
"TranscriptionJobSummaries": [
job.response_object("LIST") for job in jobs_paginated
]
@ -627,8 +665,8 @@ class TranscribeBackend(BaseBackend):
return response
def list_medical_transcription_jobs(
self, status, job_name_contains, next_token, max_results
):
self, status: str, job_name_contains: str, next_token: str, max_results: int
) -> Dict[str, Any]:
jobs = list(self.medical_transcriptions.values())
if status:
@ -647,7 +685,7 @@ class TranscribeBackend(BaseBackend):
) # Arbitrarily selected...
jobs_paginated = jobs[start_offset:end_offset]
response = {
response: Dict[str, Any] = {
"MedicalTranscriptionJobSummaries": [
job.response_object("LIST") for job in jobs_paginated
]
@ -658,12 +696,13 @@ class TranscribeBackend(BaseBackend):
response["Status"] = status
return response
def create_vocabulary(self, **kwargs):
vocabulary_name = kwargs.get("vocabulary_name")
language_code = kwargs.get("language_code")
phrases = kwargs.get("phrases")
vocabulary_file_uri = kwargs.get("vocabulary_file_uri")
def create_vocabulary(
self,
vocabulary_name: str,
language_code: str,
phrases: Optional[List[str]],
vocabulary_file_uri: Optional[str],
) -> Dict[str, Any]:
if (
phrases is not None
and vocabulary_file_uri is not None
@ -698,12 +737,12 @@ class TranscribeBackend(BaseBackend):
return vocabulary_object.response_object("CREATE")
def create_medical_vocabulary(self, **kwargs):
vocabulary_name = kwargs.get("vocabulary_name")
language_code = kwargs.get("language_code")
vocabulary_file_uri = kwargs.get("vocabulary_file_uri")
def create_medical_vocabulary(
self,
vocabulary_name: str,
language_code: str,
vocabulary_file_uri: Optional[str],
) -> Dict[str, Any]:
if vocabulary_name in self.medical_vocabularies:
raise ConflictException(
message="The requested vocabulary name already exists. "
@ -722,7 +761,7 @@ class TranscribeBackend(BaseBackend):
return medical_vocabulary_object.response_object("CREATE")
def get_vocabulary(self, vocabulary_name):
def get_vocabulary(self, vocabulary_name: str) -> Dict[str, Any]:
try:
job = self.vocabularies[vocabulary_name]
job.advance() # Fakes advancement through statuses.
@ -733,7 +772,7 @@ class TranscribeBackend(BaseBackend):
"Check the vocabulary name and try your request again."
)
def get_medical_vocabulary(self, vocabulary_name):
def get_medical_vocabulary(self, vocabulary_name: str) -> Dict[str, Any]:
try:
job = self.medical_vocabularies[vocabulary_name]
job.advance() # Fakes advancement through statuses.
@ -744,7 +783,7 @@ class TranscribeBackend(BaseBackend):
"Check the vocabulary name and try your request again."
)
def delete_vocabulary(self, vocabulary_name):
def delete_vocabulary(self, vocabulary_name: str) -> None:
try:
del self.vocabularies[vocabulary_name]
except KeyError:
@ -752,7 +791,7 @@ class TranscribeBackend(BaseBackend):
message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again."
)
def delete_medical_vocabulary(self, vocabulary_name):
def delete_medical_vocabulary(self, vocabulary_name: str) -> None:
try:
del self.medical_vocabularies[vocabulary_name]
except KeyError:
@ -760,7 +799,9 @@ class TranscribeBackend(BaseBackend):
message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again."
)
def list_vocabularies(self, state_equals, name_contains, next_token, max_results):
def list_vocabularies(
self, state_equals: str, name_contains: str, next_token: str, max_results: int
) -> Dict[str, Any]:
vocabularies = list(self.vocabularies.values())
if state_equals:
@ -783,7 +824,7 @@ class TranscribeBackend(BaseBackend):
) # Arbitrarily selected...
vocabularies_paginated = vocabularies[start_offset:end_offset]
response = {
response: Dict[str, Any] = {
"Vocabularies": [
vocabulary.response_object("LIST")
for vocabulary in vocabularies_paginated
@ -796,8 +837,8 @@ class TranscribeBackend(BaseBackend):
return response
def list_medical_vocabularies(
self, state_equals, name_contains, next_token, max_results
):
self, state_equals: str, name_contains: str, next_token: str, max_results: int
) -> Dict[str, Any]:
vocabularies = list(self.medical_vocabularies.values())
if state_equals:
@ -820,7 +861,7 @@ class TranscribeBackend(BaseBackend):
) # Arbitrarily selected...
vocabularies_paginated = vocabularies[start_offset:end_offset]
response = {
response: Dict[str, Any] = {
"Vocabularies": [
vocabulary.response_object("LIST")
for vocabulary in vocabularies_paginated

View File

@ -2,26 +2,19 @@ import json
from moto.core.responses import BaseResponse
from moto.utilities.aws_headers import amzn_request_id
from .models import transcribe_backends
from .models import transcribe_backends, TranscribeBackend
class TranscribeResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="transcribe")
@property
def transcribe_backend(self):
def transcribe_backend(self) -> TranscribeBackend:
return transcribe_backends[self.current_account][self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
@amzn_request_id
def start_transcription_job(self):
def start_transcription_job(self) -> str:
name = self._get_param("TranscriptionJobName")
response = self.transcribe_backend.start_transcription_job(
transcription_job_name=name,
@ -43,7 +36,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def start_medical_transcription_job(self):
def start_medical_transcription_job(self) -> str:
name = self._get_param("MedicalTranscriptionJobName")
response = self.transcribe_backend.start_medical_transcription_job(
medical_transcription_job_name=name,
@ -55,12 +48,12 @@ class TranscribeResponse(BaseResponse):
output_encryption_kms_key_id=self._get_param("OutputEncryptionKMSKeyId"),
settings=self._get_param("Settings"),
specialty=self._get_param("Specialty"),
type=self._get_param("Type"),
type_=self._get_param("Type"),
)
return json.dumps(response)
@amzn_request_id
def list_transcription_jobs(self):
def list_transcription_jobs(self) -> str:
state_equals = self._get_param("Status")
job_name_contains = self._get_param("JobNameContains")
next_token = self._get_param("NextToken")
@ -75,7 +68,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def list_medical_transcription_jobs(self):
def list_medical_transcription_jobs(self) -> str:
status = self._get_param("Status")
job_name_contains = self._get_param("JobNameContains")
next_token = self._get_param("NextToken")
@ -90,7 +83,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def get_transcription_job(self):
def get_transcription_job(self) -> str:
transcription_job_name = self._get_param("TranscriptionJobName")
response = self.transcribe_backend.get_transcription_job(
transcription_job_name=transcription_job_name
@ -98,7 +91,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def get_medical_transcription_job(self):
def get_medical_transcription_job(self) -> str:
medical_transcription_job_name = self._get_param("MedicalTranscriptionJobName")
response = self.transcribe_backend.get_medical_transcription_job(
medical_transcription_job_name=medical_transcription_job_name
@ -106,23 +99,23 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def delete_transcription_job(self):
def delete_transcription_job(self) -> str:
transcription_job_name = self._get_param("TranscriptionJobName")
response = self.transcribe_backend.delete_transcription_job(
self.transcribe_backend.delete_transcription_job(
transcription_job_name=transcription_job_name
)
return json.dumps(response)
return "{}"
@amzn_request_id
def delete_medical_transcription_job(self):
def delete_medical_transcription_job(self) -> str:
medical_transcription_job_name = self._get_param("MedicalTranscriptionJobName")
response = self.transcribe_backend.delete_medical_transcription_job(
self.transcribe_backend.delete_medical_transcription_job(
medical_transcription_job_name=medical_transcription_job_name
)
return json.dumps(response)
return "{}"
@amzn_request_id
def create_vocabulary(self):
def create_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName")
language_code = self._get_param("LanguageCode")
phrases = self._get_param("Phrases")
@ -136,7 +129,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def create_medical_vocabulary(self):
def create_medical_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName")
language_code = self._get_param("LanguageCode")
vocabulary_file_uri = self._get_param("VocabularyFileUri")
@ -148,7 +141,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def get_vocabulary(self):
def get_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName")
response = self.transcribe_backend.get_vocabulary(
vocabulary_name=vocabulary_name
@ -156,7 +149,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def get_medical_vocabulary(self):
def get_medical_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName")
response = self.transcribe_backend.get_medical_vocabulary(
vocabulary_name=vocabulary_name
@ -164,7 +157,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def list_vocabularies(self):
def list_vocabularies(self) -> str:
state_equals = self._get_param("StateEquals")
name_contains = self._get_param("NameContains")
next_token = self._get_param("NextToken")
@ -179,7 +172,7 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def list_medical_vocabularies(self):
def list_medical_vocabularies(self) -> str:
state_equals = self._get_param("StateEquals")
name_contains = self._get_param("NameContains")
next_token = self._get_param("NextToken")
@ -194,17 +187,15 @@ class TranscribeResponse(BaseResponse):
return json.dumps(response)
@amzn_request_id
def delete_vocabulary(self):
def delete_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName")
response = self.transcribe_backend.delete_vocabulary(
vocabulary_name=vocabulary_name
)
return json.dumps(response)
self.transcribe_backend.delete_vocabulary(vocabulary_name=vocabulary_name)
return "{}"
@amzn_request_id
def delete_medical_vocabulary(self):
def delete_medical_vocabulary(self) -> str:
vocabulary_name = self._get_param("VocabularyName")
response = self.transcribe_backend.delete_medical_vocabulary(
self.transcribe_backend.delete_medical_vocabulary(
vocabulary_name=vocabulary_name
)
return json.dumps(response)
return "{}"

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u*
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u*,moto/t*
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract