Add subtitle support for transcribe
(#7028)
This commit is contained in:
parent
63e869d717
commit
86f1d53f54
@ -1,3 +1,4 @@
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel
|
||||
@ -48,6 +49,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
||||
identify_language: Optional[bool],
|
||||
identify_multiple_languages: Optional[bool],
|
||||
language_options: Optional[List[str]],
|
||||
subtitles: Optional[Dict[str, Any]],
|
||||
):
|
||||
ManagedState.__init__(
|
||||
self,
|
||||
@ -95,6 +97,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
||||
self.output_location_type = (
|
||||
"CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
|
||||
)
|
||||
self.subtitles = subtitles or {"Formats": [], "OutputStartIndex": 0}
|
||||
|
||||
def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
|
||||
response_field_dict = {
|
||||
@ -112,6 +115,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
||||
"IdentifyMultipleLanguages",
|
||||
"LanguageOptions",
|
||||
"JobExecutionSettings",
|
||||
"Subtitles",
|
||||
],
|
||||
"GET": [
|
||||
"TranscriptionJobName",
|
||||
@ -130,6 +134,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
||||
"IdentifyMultipleLanguages",
|
||||
"LanguageOptions",
|
||||
"IdentifiedLanguageScore",
|
||||
"Subtitles",
|
||||
],
|
||||
"LIST": [
|
||||
"TranscriptionJobName",
|
||||
@ -217,15 +222,28 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
if self._output_bucket_name:
|
||||
transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/{self._output_bucket_name}/"
|
||||
if self.output_key is not None:
|
||||
transcript_file_uri += f"{self.output_key}/"
|
||||
transcript_file_uri += f"{self.transcription_job_name}.json"
|
||||
remove_json_extension = re.compile("\\.json$")
|
||||
transcript_file_prefix = (
|
||||
f"https://s3.{self._region_name}.amazonaws.com/"
|
||||
f"{self._output_bucket_name}/"
|
||||
f"{remove_json_extension.sub('', self.output_key or self.transcription_job_name)}"
|
||||
)
|
||||
self.output_location_type = "CUSTOMER_BUCKET"
|
||||
else:
|
||||
transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/aws-transcribe-{self._region_name}-prod/{self._account_id}/{self.transcription_job_name}/{mock_random.uuid4()}/asrOutput.json"
|
||||
transcript_file_prefix = (
|
||||
f"https://s3.{self._region_name}.amazonaws.com/"
|
||||
f"aws-transcribe-{self._region_name}-prod/"
|
||||
f"{self._account_id}/"
|
||||
f"{self.transcription_job_name}/"
|
||||
f"{mock_random.uuid4()}/"
|
||||
"asrOutput"
|
||||
)
|
||||
self.output_location_type = "SERVICE_BUCKET"
|
||||
self.transcript = {"TranscriptFileUri": transcript_file_uri}
|
||||
self.transcript = {"TranscriptFileUri": f"{transcript_file_prefix}.json"}
|
||||
self.subtitles["SubtitleFileUris"] = [
|
||||
f"{transcript_file_prefix}.{format}"
|
||||
for format in self.subtitles["Formats"]
|
||||
]
|
||||
|
||||
|
||||
class FakeVocabulary(BaseObject, ManagedState):
|
||||
@ -504,6 +522,7 @@ class TranscribeBackend(BaseBackend):
|
||||
identify_language: Optional[bool],
|
||||
identify_multiple_languages: Optional[bool],
|
||||
language_options: Optional[List[str]],
|
||||
subtitles: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
if transcription_job_name in self.transcriptions:
|
||||
raise ConflictException(
|
||||
@ -535,6 +554,7 @@ class TranscribeBackend(BaseBackend):
|
||||
identify_language=identify_language,
|
||||
identify_multiple_languages=identify_multiple_languages,
|
||||
language_options=language_options,
|
||||
subtitles=subtitles,
|
||||
)
|
||||
self.transcriptions[transcription_job_name] = transcription_job_object
|
||||
|
||||
|
@ -32,6 +32,7 @@ class TranscribeResponse(BaseResponse):
|
||||
identify_language=self._get_param("IdentifyLanguage"),
|
||||
identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"),
|
||||
language_options=self._get_param("LanguageOptions"),
|
||||
subtitles=self._get_param("Subtitles"),
|
||||
)
|
||||
return json.dumps(response)
|
||||
|
||||
|
@ -209,6 +209,10 @@ def test_run_transcription_job_all_params():
|
||||
"MaxAlternatives": 6,
|
||||
"VocabularyName": vocabulary_name,
|
||||
},
|
||||
"Subtitles": {
|
||||
"Formats": ["srt", "vtt"],
|
||||
"OutputStartIndex": 1,
|
||||
},
|
||||
# Missing `ContentRedaction`, `JobExecutionSettings`,
|
||||
# `VocabularyFilterName`, `LanguageModel`
|
||||
}
|
||||
@ -269,6 +273,18 @@ def test_run_transcription_job_all_params():
|
||||
f"/{args['TranscriptionJobName']}.json"
|
||||
),
|
||||
}
|
||||
assert transcription_job["Subtitles"] == {
|
||||
"Formats": args["Subtitles"]["Formats"],
|
||||
"OutputStartIndex": 1,
|
||||
"SubtitleFileUris": [
|
||||
(
|
||||
f"https://s3.{region_name}.amazonaws.com"
|
||||
f"/{args['OutputBucketName']}"
|
||||
f"/{args['TranscriptionJobName']}.{format}"
|
||||
)
|
||||
for format in args["Subtitles"]["Formats"]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@mock_transcribe
|
||||
@ -329,6 +345,11 @@ def test_run_transcription_job_minimal_params():
|
||||
assert (
|
||||
f"https://s3.{region_name}.amazonaws.com/aws-transcribe-{region_name}-prod/"
|
||||
) in transcription_job["Transcript"]["TranscriptFileUri"]
|
||||
assert transcription_job["Subtitles"] == {
|
||||
"Formats": [],
|
||||
"OutputStartIndex": 0,
|
||||
"SubtitleFileUris": [],
|
||||
}
|
||||
|
||||
# Delete
|
||||
client.delete_transcription_job(TranscriptionJobName=job_name)
|
||||
@ -348,7 +369,8 @@ def test_run_transcription_job_s3output_params():
|
||||
"LanguageCode": "en-US",
|
||||
"Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"},
|
||||
"OutputBucketName": "my-output-bucket",
|
||||
"OutputKey": "bucket-key",
|
||||
"OutputKey": "bucket.json.key.json",
|
||||
"Subtitles": {"Formats": ["vtt", "srt"]},
|
||||
}
|
||||
resp = client.start_transcription_job(**args)
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
@ -379,8 +401,16 @@ def test_run_transcription_job_s3output_params():
|
||||
assert "Transcript" in transcription_job
|
||||
# Check aws hosted bucket
|
||||
assert (
|
||||
"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket-key/MyJob.json"
|
||||
"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.json"
|
||||
) in transcription_job["Transcript"]["TranscriptFileUri"]
|
||||
assert transcription_job["Subtitles"] == {
|
||||
"Formats": args["Subtitles"]["Formats"],
|
||||
"SubtitleFileUris": [
|
||||
f"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.{format}"
|
||||
for format in args["Subtitles"]["Formats"]
|
||||
],
|
||||
}
|
||||
|
||||
# A new job without an "OutputKey"
|
||||
job_name = "MyJob2"
|
||||
args = {
|
||||
|
Loading…
Reference in New Issue
Block a user