Add subtitle support for transcribe (#7028)

This commit is contained in:
Ross Cooperman 2023-11-16 05:15:56 -05:00 committed by GitHub
parent 63e869d717
commit 86f1d53f54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 8 deletions

View File

@ -1,3 +1,4 @@
import re
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
@ -48,6 +49,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
identify_language: Optional[bool],
identify_multiple_languages: Optional[bool],
language_options: Optional[List[str]],
subtitles: Optional[Dict[str, Any]],
):
ManagedState.__init__(
self,
@ -95,6 +97,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self.output_location_type = (
"CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
)
self.subtitles = subtitles or {"Formats": [], "OutputStartIndex": 0}
def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
response_field_dict = {
@ -112,6 +115,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"IdentifyMultipleLanguages",
"LanguageOptions",
"JobExecutionSettings",
"Subtitles",
],
"GET": [
"TranscriptionJobName",
@ -130,6 +134,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"IdentifyMultipleLanguages",
"LanguageOptions",
"IdentifiedLanguageScore",
"Subtitles",
],
"LIST": [
"TranscriptionJobName",
@ -217,15 +222,28 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"%Y-%m-%d %H:%M:%S"
)
if self._output_bucket_name:
transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/{self._output_bucket_name}/"
if self.output_key is not None:
transcript_file_uri += f"{self.output_key}/"
transcript_file_uri += f"{self.transcription_job_name}.json"
remove_json_extension = re.compile("\\.json$")
transcript_file_prefix = (
f"https://s3.{self._region_name}.amazonaws.com/"
f"{self._output_bucket_name}/"
f"{remove_json_extension.sub('', self.output_key or self.transcription_job_name)}"
)
self.output_location_type = "CUSTOMER_BUCKET"
else:
transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/aws-transcribe-{self._region_name}-prod/{self._account_id}/{self.transcription_job_name}/{mock_random.uuid4()}/asrOutput.json"
transcript_file_prefix = (
f"https://s3.{self._region_name}.amazonaws.com/"
f"aws-transcribe-{self._region_name}-prod/"
f"{self._account_id}/"
f"{self.transcription_job_name}/"
f"{mock_random.uuid4()}/"
"asrOutput"
)
self.output_location_type = "SERVICE_BUCKET"
self.transcript = {"TranscriptFileUri": transcript_file_uri}
self.transcript = {"TranscriptFileUri": f"{transcript_file_prefix}.json"}
self.subtitles["SubtitleFileUris"] = [
f"{transcript_file_prefix}.{format}"
for format in self.subtitles["Formats"]
]
class FakeVocabulary(BaseObject, ManagedState):
@ -504,6 +522,7 @@ class TranscribeBackend(BaseBackend):
identify_language: Optional[bool],
identify_multiple_languages: Optional[bool],
language_options: Optional[List[str]],
subtitles: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
if transcription_job_name in self.transcriptions:
raise ConflictException(
@ -535,6 +554,7 @@ class TranscribeBackend(BaseBackend):
identify_language=identify_language,
identify_multiple_languages=identify_multiple_languages,
language_options=language_options,
subtitles=subtitles,
)
self.transcriptions[transcription_job_name] = transcription_job_object

View File

@ -32,6 +32,7 @@ class TranscribeResponse(BaseResponse):
identify_language=self._get_param("IdentifyLanguage"),
identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"),
language_options=self._get_param("LanguageOptions"),
subtitles=self._get_param("Subtitles"),
)
return json.dumps(response)

View File

@ -209,6 +209,10 @@ def test_run_transcription_job_all_params():
"MaxAlternatives": 6,
"VocabularyName": vocabulary_name,
},
"Subtitles": {
"Formats": ["srt", "vtt"],
"OutputStartIndex": 1,
},
# Missing `ContentRedaction`, `JobExecutionSettings`,
# `VocabularyFilterName`, `LanguageModel`
}
@ -269,6 +273,18 @@ def test_run_transcription_job_all_params():
f"/{args['TranscriptionJobName']}.json"
),
}
assert transcription_job["Subtitles"] == {
"Formats": args["Subtitles"]["Formats"],
"OutputStartIndex": 1,
"SubtitleFileUris": [
(
f"https://s3.{region_name}.amazonaws.com"
f"/{args['OutputBucketName']}"
f"/{args['TranscriptionJobName']}.{format}"
)
for format in args["Subtitles"]["Formats"]
],
}
@mock_transcribe
@ -329,6 +345,11 @@ def test_run_transcription_job_minimal_params():
assert (
f"https://s3.{region_name}.amazonaws.com/aws-transcribe-{region_name}-prod/"
) in transcription_job["Transcript"]["TranscriptFileUri"]
assert transcription_job["Subtitles"] == {
"Formats": [],
"OutputStartIndex": 0,
"SubtitleFileUris": [],
}
# Delete
client.delete_transcription_job(TranscriptionJobName=job_name)
@ -348,7 +369,8 @@ def test_run_transcription_job_s3output_params():
"LanguageCode": "en-US",
"Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"},
"OutputBucketName": "my-output-bucket",
"OutputKey": "bucket-key",
"OutputKey": "bucket.json.key.json",
"Subtitles": {"Formats": ["vtt", "srt"]},
}
resp = client.start_transcription_job(**args)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
@ -379,8 +401,16 @@ def test_run_transcription_job_s3output_params():
assert "Transcript" in transcription_job
# Check aws hosted bucket
assert (
"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket-key/MyJob.json"
"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.json"
) in transcription_job["Transcript"]["TranscriptFileUri"]
assert transcription_job["Subtitles"] == {
"Formats": args["Subtitles"]["Formats"],
"SubtitleFileUris": [
f"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.{format}"
for format in args["Subtitles"]["Formats"]
],
}
# A new job without an "OutputKey"
job_name = "MyJob2"
args = {