moto/moto/transcribe/models.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

827 lines
30 KiB
Python
Raw Normal View History

from datetime import datetime, timedelta
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api import state_manager
from moto.moto_api._internal import mock_random
from moto.moto_api._internal.managed_state_model import ManagedState
from .exceptions import ConflictException, BadRequestException
class BaseObject(BaseModel):
def camelCase(self, key):
words = []
2022-01-14 19:51:49 +00:00
for word in key.split("_"):
words.append(word.title())
return "".join(words)
def gen_response_object(self):
response_object = dict()
for key, value in self.__dict__.items():
if "_" in key:
response_object[self.camelCase(key)] = value
else:
response_object[key[0].upper() + key[1:]] = value
return response_object
@property
def response_object(self):
return self.gen_response_object()
class FakeTranscriptionJob(BaseObject, ManagedState):
def __init__(
self,
2022-08-13 09:49:43 +00:00
account_id,
region_name,
transcription_job_name,
language_code,
media_sample_rate_hertz,
media_format,
media,
output_bucket_name,
output_key,
output_encryption_kms_key_id,
settings,
model_settings,
job_execution_settings,
content_redaction,
identify_language,
language_options,
):
ManagedState.__init__(
self,
"transcribe::transcriptionjob",
transitions=[
(None, "QUEUED"),
("QUEUED", "IN_PROGRESS"),
("IN_PROGRESS", "COMPLETED"),
],
)
2022-08-13 09:49:43 +00:00
self._account_id = account_id
self._region_name = region_name
self.transcription_job_name = transcription_job_name
self.language_code = language_code
self.media_sample_rate_hertz = media_sample_rate_hertz
self.media_format = media_format
self.media = media
self.transcript = None
self.start_time = self.completion_time = None
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.failure_reason = None
self.settings = settings or {
"ChannelIdentification": False,
"ShowAlternatives": False,
"ShowSpeakerLabels": False,
}
self.model_settings = model_settings or {"LanguageModelName": None}
self.job_execution_settings = job_execution_settings or {
"AllowDeferredExecution": False,
"DataAccessRoleArn": None,
}
self.content_redaction = content_redaction or {
"RedactionType": None,
"RedactionOutput": None,
}
self.identify_language = identify_language
self.language_options = language_options
self.identified_language_score = (None,)
self._output_bucket_name = output_bucket_name
self.output_key = output_key
self._output_encryption_kms_key_id = output_encryption_kms_key_id
self.output_location_type = (
"CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
)
def response_object(self, response_type):
response_field_dict = {
"CREATE": [
"TranscriptionJobName",
"TranscriptionJobStatus",
"LanguageCode",
"MediaFormat",
"Media",
"Settings",
"StartTime",
"CreationTime",
"IdentifyLanguage",
"LanguageOptions",
"JobExecutionSettings",
],
"GET": [
"TranscriptionJobName",
"TranscriptionJobStatus",
"LanguageCode",
"MediaSampleRateHertz",
"MediaFormat",
"Media",
"Settings",
"Transcript",
"StartTime",
"CreationTime",
"CompletionTime",
"IdentifyLanguage",
"LanguageOptions",
"IdentifiedLanguageScore",
],
"LIST": [
"TranscriptionJobName",
"CreationTime",
"StartTime",
"CompletionTime",
"LanguageCode",
"TranscriptionJobStatus",
"FailureReason",
"IdentifyLanguage",
"IdentifiedLanguageScore",
"OutputLocationType",
],
}
response_fields = response_field_dict[response_type]
response_object = self.gen_response_object()
response_object["TranscriptionJobStatus"] = self.status
if response_type != "LIST":
return {
"TranscriptionJob": {
k: v
for k, v in response_object.items()
if k in response_fields and v is not None and v != [None]
}
}
else:
return {
k: v
for k, v in response_object.items()
if k in response_fields and v is not None and v != [None]
}
def advance(self):
old_status = self.status
super().advance()
new_status = self.status
if old_status == new_status:
return
if new_status == "IN_PROGRESS":
self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if not self.media_sample_rate_hertz:
self.media_sample_rate_hertz = 44100
if not self.media_format:
file_ext = self.media["MediaFileUri"].split(".")[-1].lower()
self.media_format = (
file_ext if file_ext in ["mp3", "mp4", "wav", "flac"] else "mp3"
)
if self.identify_language:
self.identified_language_score = 0.999645948
# Simply identify first language passed in lanugage_options
# If non is set default to "en-US"
if self.language_options is not None and len(self.language_options) > 0:
self.language_code = self.language_options[0]
else:
self.language_code = "en-US"
elif new_status == "COMPLETED":
self.completion_time = (datetime.now() + timedelta(seconds=10)).strftime(
"%Y-%m-%d %H:%M:%S"
)
if self._output_bucket_name:
transcript_file_uri = "https://s3.{0}.amazonaws.com/{1}/".format(
self._region_name, self._output_bucket_name
)
transcript_file_uri = (
transcript_file_uri
+ "{0}/{1}.json".format(
self.output_key, self.transcription_job_name
)
if self.output_key is not None
else transcript_file_uri
+ "{transcription_job_name}.json".format(
transcription_job_name=self.transcription_job_name
)
)
self.output_location_type = "CUSTOMER_BUCKET"
else:
transcript_file_uri = "https://s3.{0}.amazonaws.com/aws-transcribe-{0}-prod/{1}/{2}/{3}/asrOutput.json".format( # noqa: E501
self._region_name,
2022-08-13 09:49:43 +00:00
self._account_id,
self.transcription_job_name,
mock_random.uuid4(),
)
self.output_location_type = "SERVICE_BUCKET"
self.transcript = {"TranscriptFileUri": transcript_file_uri}
class FakeVocabulary(BaseObject, ManagedState):
def __init__(
2022-08-13 09:49:43 +00:00
self,
account_id,
region_name,
vocabulary_name,
language_code,
phrases,
vocabulary_file_uri,
):
# Configured ManagedState
super().__init__(
"transcribe::vocabulary",
transitions=[(None, "PENDING"), ("PENDING", "READY")],
)
# Configure internal properties
self._region_name = region_name
self.vocabulary_name = vocabulary_name
self.language_code = language_code
self.phrases = phrases
self.vocabulary_file_uri = vocabulary_file_uri
self.last_modified_time = None
self.failure_reason = None
self.download_uri = "https://s3.{0}.amazonaws.com/aws-transcribe-dictionary-model-{0}-prod/{1}/{2}/{3}/input.txt".format( # noqa: E501
region_name, account_id, vocabulary_name, mock_random.uuid4()
)
def response_object(self, response_type):
response_field_dict = {
"CREATE": [
"VocabularyName",
"LanguageCode",
"VocabularyState",
"LastModifiedTime",
"FailureReason",
],
"GET": [
"VocabularyName",
"LanguageCode",
"VocabularyState",
"LastModifiedTime",
"FailureReason",
"DownloadUri",
],
"LIST": [
"VocabularyName",
"LanguageCode",
"LastModifiedTime",
"VocabularyState",
],
}
response_fields = response_field_dict[response_type]
response_object = self.gen_response_object()
response_object["VocabularyState"] = self.status
return {
k: v
for k, v in response_object.items()
if k in response_fields and v is not None and v != [None]
}
def advance(self):
old_status = self.status
super().advance()
new_status = self.status
if old_status != new_status:
self.last_modified_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
class FakeMedicalTranscriptionJob(BaseObject, ManagedState):
def __init__(
self,
region_name,
medical_transcription_job_name,
language_code,
media_sample_rate_hertz,
media_format,
media,
output_bucket_name,
output_encryption_kms_key_id,
settings,
specialty,
2022-01-14 19:51:49 +00:00
job_type,
):
ManagedState.__init__(
self,
"transcribe::medicaltranscriptionjob",
transitions=[
(None, "QUEUED"),
("QUEUED", "IN_PROGRESS"),
("IN_PROGRESS", "COMPLETED"),
],
)
self._region_name = region_name
self.medical_transcription_job_name = medical_transcription_job_name
self.language_code = language_code
self.media_sample_rate_hertz = media_sample_rate_hertz
self.media_format = media_format
self.media = media
self.transcript = None
self.start_time = self.completion_time = None
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.failure_reason = None
self.settings = settings or {
"ChannelIdentification": False,
"ShowAlternatives": False,
}
self.specialty = specialty
2022-01-14 19:51:49 +00:00
self.type = job_type
self._output_bucket_name = output_bucket_name
self._output_encryption_kms_key_id = output_encryption_kms_key_id
self.output_location_type = "CUSTOMER_BUCKET"
def response_object(self, response_type):
response_field_dict = {
"CREATE": [
"MedicalTranscriptionJobName",
"TranscriptionJobStatus",
"LanguageCode",
"MediaFormat",
"Media",
"StartTime",
"CreationTime",
"Specialty",
"Type",
],
"GET": [
"MedicalTranscriptionJobName",
"TranscriptionJobStatus",
"LanguageCode",
"MediaSampleRateHertz",
"MediaFormat",
"Media",
"Transcript",
"StartTime",
"CreationTime",
"CompletionTime",
"Settings",
"Specialty",
"Type",
],
"LIST": [
"MedicalTranscriptionJobName",
"CreationTime",
"StartTime",
"CompletionTime",
"LanguageCode",
"TranscriptionJobStatus",
"FailureReason",
"OutputLocationType",
"Specialty",
"Type",
],
}
response_fields = response_field_dict[response_type]
response_object = self.gen_response_object()
response_object["TranscriptionJobStatus"] = self.status
if response_type != "LIST":
return {
"MedicalTranscriptionJob": {
k: v
for k, v in response_object.items()
if k in response_fields and v is not None and v != [None]
}
}
else:
return {
k: v
for k, v in response_object.items()
if k in response_fields and v is not None and v != [None]
}
def advance(self):
old_status = self.status
super().advance()
new_status = self.status
if old_status == new_status:
return
if new_status == "IN_PROGRESS":
self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if not self.media_sample_rate_hertz:
self.media_sample_rate_hertz = 44100
if not self.media_format:
file_ext = self.media["MediaFileUri"].split(".")[-1].lower()
self.media_format = (
file_ext if file_ext in ["mp3", "mp4", "wav", "flac"] else "mp3"
)
elif new_status == "COMPLETED":
self.completion_time = (datetime.now() + timedelta(seconds=10)).strftime(
"%Y-%m-%d %H:%M:%S"
)
self.transcript = {
"TranscriptFileUri": "https://s3.{}.amazonaws.com/{}/medical/{}.json".format(
self._region_name,
self._output_bucket_name,
self.medical_transcription_job_name,
)
}
class FakeMedicalVocabulary(FakeVocabulary):
def __init__(
2022-08-13 09:49:43 +00:00
self,
account_id,
region_name,
vocabulary_name,
language_code,
vocabulary_file_uri,
):
super().__init__(
2022-08-13 09:49:43 +00:00
account_id,
region_name,
vocabulary_name,
language_code=language_code,
phrases=None,
vocabulary_file_uri=vocabulary_file_uri,
)
self.model_name = "transcribe::medicalvocabulary"
self._region_name = region_name
self.vocabulary_name = vocabulary_name
self.language_code = language_code
self.vocabulary_file_uri = vocabulary_file_uri
self.last_modified_time = None
self.failure_reason = None
self.download_uri = "https://s3.us-east-1.amazonaws.com/aws-transcribe-dictionary-model-{}-prod/{}/medical/{}/{}/input.txt".format( # noqa: E501
region_name, account_id, self.vocabulary_name, mock_random.uuid4()
)
class TranscribeBackend(BaseBackend):
def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.medical_transcriptions = {}
self.transcriptions = {}
self.medical_vocabularies = {}
self.vocabularies = {}
state_manager.register_default_transition(
"transcribe::vocabulary", transition={"progression": "manual", "times": 1}
)
state_manager.register_default_transition(
"transcribe::medicalvocabulary",
transition={"progression": "manual", "times": 1},
)
state_manager.register_default_transition(
"transcribe::transcriptionjob",
transition={"progression": "manual", "times": 1},
)
state_manager.register_default_transition(
"transcribe::medicaltranscriptionjob",
transition={"progression": "manual", "times": 1},
)
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint services."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "transcribe"
) + BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "transcribestreaming"
)
def start_transcription_job(self, **kwargs):
name = kwargs.get("transcription_job_name")
if name in self.transcriptions:
raise ConflictException(
message="The requested job name already exists. Use a different job name."
)
settings = kwargs.get("settings")
vocabulary_name = settings.get("VocabularyName") if settings else None
if vocabulary_name and vocabulary_name not in self.vocabularies:
raise BadRequestException(
message="The requested vocabulary couldn't be found. "
"Check the vocabulary name and try your request again."
)
transcription_job_object = FakeTranscriptionJob(
2022-08-13 09:49:43 +00:00
account_id=self.account_id,
region_name=self.region_name,
transcription_job_name=name,
language_code=kwargs.get("language_code"),
media_sample_rate_hertz=kwargs.get("media_sample_rate_hertz"),
media_format=kwargs.get("media_format"),
media=kwargs.get("media"),
output_bucket_name=kwargs.get("output_bucket_name"),
output_key=kwargs.get("output_key"),
output_encryption_kms_key_id=kwargs.get("output_encryption_kms_key_id"),
settings=settings,
model_settings=kwargs.get("model_settings"),
job_execution_settings=kwargs.get("job_execution_settings"),
content_redaction=kwargs.get("content_redaction"),
identify_language=kwargs.get("identify_language"),
language_options=kwargs.get("language_options"),
)
self.transcriptions[name] = transcription_job_object
return transcription_job_object.response_object("CREATE")
def start_medical_transcription_job(self, **kwargs):
name = kwargs.get("medical_transcription_job_name")
if name in self.medical_transcriptions:
raise ConflictException(
message="The requested job name already exists. Use a different job name."
)
settings = kwargs.get("settings")
vocabulary_name = settings.get("VocabularyName") if settings else None
if vocabulary_name and vocabulary_name not in self.medical_vocabularies:
raise BadRequestException(
message="The requested vocabulary couldn't be found. "
"Check the vocabulary name and try your request again."
)
transcription_job_object = FakeMedicalTranscriptionJob(
region_name=self.region_name,
medical_transcription_job_name=name,
language_code=kwargs.get("language_code"),
media_sample_rate_hertz=kwargs.get("media_sample_rate_hertz"),
media_format=kwargs.get("media_format"),
media=kwargs.get("media"),
output_bucket_name=kwargs.get("output_bucket_name"),
output_encryption_kms_key_id=kwargs.get("output_encryption_kms_key_id"),
settings=settings,
specialty=kwargs.get("specialty"),
2022-01-14 19:51:49 +00:00
job_type=kwargs.get("type"),
)
self.medical_transcriptions[name] = transcription_job_object
return transcription_job_object.response_object("CREATE")
def get_transcription_job(self, transcription_job_name):
try:
job = self.transcriptions[transcription_job_name]
job.advance() # Fakes advancement through statuses.
return job.response_object("GET")
except KeyError:
raise BadRequestException(
message="The requested job couldn't be found. "
"Check the job name and try your request again."
)
def get_medical_transcription_job(self, medical_transcription_job_name):
try:
job = self.medical_transcriptions[medical_transcription_job_name]
job.advance() # Fakes advancement through statuses.
return job.response_object("GET")
except KeyError:
raise BadRequestException(
message="The requested job couldn't be found. "
"Check the job name and try your request again."
)
def delete_transcription_job(self, transcription_job_name):
try:
del self.transcriptions[transcription_job_name]
except KeyError:
raise BadRequestException(
message="The requested job couldn't be found. "
"Check the job name and try your request again."
)
def delete_medical_transcription_job(self, medical_transcription_job_name):
try:
del self.medical_transcriptions[medical_transcription_job_name]
except KeyError:
raise BadRequestException(
message="The requested job couldn't be found. "
"Check the job name and try your request again."
)
def list_transcription_jobs(
self, state_equals, job_name_contains, next_token, max_results
):
jobs = list(self.transcriptions.values())
if state_equals:
jobs = [job for job in jobs if job.status == state_equals]
if job_name_contains:
jobs = [
job for job in jobs if job_name_contains in job.transcription_job_name
]
start_offset = int(next_token) if next_token else 0
end_offset = start_offset + (
max_results if max_results else 100
) # Arbitrarily selected...
jobs_paginated = jobs[start_offset:end_offset]
response = {
"TranscriptionJobSummaries": [
job.response_object("LIST") for job in jobs_paginated
]
}
if end_offset < len(jobs):
response["NextToken"] = str(end_offset)
if state_equals:
response["Status"] = state_equals
return response
def list_medical_transcription_jobs(
self, status, job_name_contains, next_token, max_results
):
jobs = list(self.medical_transcriptions.values())
if status:
jobs = [job for job in jobs if job.status == status]
if job_name_contains:
jobs = [
job
for job in jobs
if job_name_contains in job.medical_transcription_job_name
]
start_offset = int(next_token) if next_token else 0
end_offset = start_offset + (
max_results if max_results else 100
) # Arbitrarily selected...
jobs_paginated = jobs[start_offset:end_offset]
response = {
"MedicalTranscriptionJobSummaries": [
job.response_object("LIST") for job in jobs_paginated
]
}
if end_offset < len(jobs):
response["NextToken"] = str(end_offset)
if status:
response["Status"] = status
return response
def create_vocabulary(self, **kwargs):
vocabulary_name = kwargs.get("vocabulary_name")
language_code = kwargs.get("language_code")
phrases = kwargs.get("phrases")
vocabulary_file_uri = kwargs.get("vocabulary_file_uri")
if (
phrases is not None
and vocabulary_file_uri is not None
or phrases is None
and vocabulary_file_uri is None
):
raise BadRequestException(
message="Either Phrases or VocabularyFileUri field should be provided."
)
if phrases is not None and len(phrases) < 1:
raise BadRequestException(
message="1 validation error detected: Value '[]' at 'phrases' failed to "
"satisfy constraint: Member must have length greater than or "
"equal to 1"
)
if vocabulary_name in self.vocabularies:
raise ConflictException(
message="The requested vocabulary name already exists. "
"Use a different vocabulary name."
)
vocabulary_object = FakeVocabulary(
2022-08-13 09:49:43 +00:00
account_id=self.account_id,
region_name=self.region_name,
vocabulary_name=vocabulary_name,
language_code=language_code,
phrases=phrases,
vocabulary_file_uri=vocabulary_file_uri,
)
self.vocabularies[vocabulary_name] = vocabulary_object
return vocabulary_object.response_object("CREATE")
def create_medical_vocabulary(self, **kwargs):
vocabulary_name = kwargs.get("vocabulary_name")
language_code = kwargs.get("language_code")
vocabulary_file_uri = kwargs.get("vocabulary_file_uri")
if vocabulary_name in self.medical_vocabularies:
raise ConflictException(
message="The requested vocabulary name already exists. "
"Use a different vocabulary name."
)
medical_vocabulary_object = FakeMedicalVocabulary(
2022-08-13 09:49:43 +00:00
account_id=self.account_id,
region_name=self.region_name,
vocabulary_name=vocabulary_name,
language_code=language_code,
vocabulary_file_uri=vocabulary_file_uri,
)
self.medical_vocabularies[vocabulary_name] = medical_vocabulary_object
return medical_vocabulary_object.response_object("CREATE")
def get_vocabulary(self, vocabulary_name):
try:
job = self.vocabularies[vocabulary_name]
job.advance() # Fakes advancement through statuses.
return job.response_object("GET")
except KeyError:
raise BadRequestException(
message="The requested vocabulary couldn't be found. "
"Check the vocabulary name and try your request again."
)
def get_medical_vocabulary(self, vocabulary_name):
try:
job = self.medical_vocabularies[vocabulary_name]
job.advance() # Fakes advancement through statuses.
return job.response_object("GET")
except KeyError:
raise BadRequestException(
message="The requested vocabulary couldn't be found. "
"Check the vocabulary name and try your request again."
)
def delete_vocabulary(self, vocabulary_name):
try:
del self.vocabularies[vocabulary_name]
except KeyError:
raise BadRequestException(
message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again."
)
def delete_medical_vocabulary(self, vocabulary_name):
try:
del self.medical_vocabularies[vocabulary_name]
except KeyError:
raise BadRequestException(
message="The requested vocabulary couldn't be found. Check the vocabulary name and try your request again."
)
def list_vocabularies(self, state_equals, name_contains, next_token, max_results):
vocabularies = list(self.vocabularies.values())
if state_equals:
vocabularies = [
vocabulary
for vocabulary in vocabularies
if vocabulary.status == state_equals
]
if name_contains:
vocabularies = [
vocabulary
for vocabulary in vocabularies
if name_contains in vocabulary.vocabulary_name
]
start_offset = int(next_token) if next_token else 0
end_offset = start_offset + (
max_results if max_results else 100
) # Arbitrarily selected...
vocabularies_paginated = vocabularies[start_offset:end_offset]
response = {
"Vocabularies": [
vocabulary.response_object("LIST")
for vocabulary in vocabularies_paginated
]
}
if end_offset < len(vocabularies):
response["NextToken"] = str(end_offset)
if state_equals:
response["Status"] = state_equals
return response
def list_medical_vocabularies(
self, state_equals, name_contains, next_token, max_results
):
vocabularies = list(self.medical_vocabularies.values())
if state_equals:
vocabularies = [
vocabulary
for vocabulary in vocabularies
if vocabulary.status == state_equals
]
if name_contains:
vocabularies = [
vocabulary
for vocabulary in vocabularies
if name_contains in vocabulary.vocabulary_name
]
start_offset = int(next_token) if next_token else 0
end_offset = start_offset + (
max_results if max_results else 100
) # Arbitrarily selected...
vocabularies_paginated = vocabularies[start_offset:end_offset]
response = {
"Vocabularies": [
vocabulary.response_object("LIST")
for vocabulary in vocabularies_paginated
]
}
if end_offset < len(vocabularies):
response["NextToken"] = str(end_offset)
if state_equals:
response["Status"] = state_equals
return response
2021-12-24 21:02:45 +00:00
transcribe_backends = BackendDict(TranscribeBackend, "transcribe")