Add subtitle support for transcribe (#7028)

This commit is contained in:
Ross Cooperman 2023-11-16 05:15:56 -05:00 committed by GitHub
parent 63e869d717
commit 86f1d53f54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 8 deletions

View File

@@ -1,3 +1,4 @@
import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
@@ -48,6 +49,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
identify_language: Optional[bool], identify_language: Optional[bool],
identify_multiple_languages: Optional[bool], identify_multiple_languages: Optional[bool],
language_options: Optional[List[str]], language_options: Optional[List[str]],
subtitles: Optional[Dict[str, Any]],
): ):
ManagedState.__init__( ManagedState.__init__(
self, self,
@@ -95,6 +97,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self.output_location_type = ( self.output_location_type = (
"CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET" "CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
) )
self.subtitles = subtitles or {"Formats": [], "OutputStartIndex": 0}
def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
response_field_dict = { response_field_dict = {
@@ -112,6 +115,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"IdentifyMultipleLanguages", "IdentifyMultipleLanguages",
"LanguageOptions", "LanguageOptions",
"JobExecutionSettings", "JobExecutionSettings",
"Subtitles",
], ],
"GET": [ "GET": [
"TranscriptionJobName", "TranscriptionJobName",
@@ -130,6 +134,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"IdentifyMultipleLanguages", "IdentifyMultipleLanguages",
"LanguageOptions", "LanguageOptions",
"IdentifiedLanguageScore", "IdentifiedLanguageScore",
"Subtitles",
], ],
"LIST": [ "LIST": [
"TranscriptionJobName", "TranscriptionJobName",
@@ -217,15 +222,28 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"%Y-%m-%d %H:%M:%S" "%Y-%m-%d %H:%M:%S"
) )
if self._output_bucket_name: if self._output_bucket_name:
transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/{self._output_bucket_name}/" remove_json_extension = re.compile("\\.json$")
if self.output_key is not None: transcript_file_prefix = (
transcript_file_uri += f"{self.output_key}/" f"https://s3.{self._region_name}.amazonaws.com/"
transcript_file_uri += f"{self.transcription_job_name}.json" f"{self._output_bucket_name}/"
f"{remove_json_extension.sub('', self.output_key or self.transcription_job_name)}"
)
self.output_location_type = "CUSTOMER_BUCKET" self.output_location_type = "CUSTOMER_BUCKET"
else: else:
transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/aws-transcribe-{self._region_name}-prod/{self._account_id}/{self.transcription_job_name}/{mock_random.uuid4()}/asrOutput.json" transcript_file_prefix = (
f"https://s3.{self._region_name}.amazonaws.com/"
f"aws-transcribe-{self._region_name}-prod/"
f"{self._account_id}/"
f"{self.transcription_job_name}/"
f"{mock_random.uuid4()}/"
"asrOutput"
)
self.output_location_type = "SERVICE_BUCKET" self.output_location_type = "SERVICE_BUCKET"
self.transcript = {"TranscriptFileUri": transcript_file_uri} self.transcript = {"TranscriptFileUri": f"{transcript_file_prefix}.json"}
self.subtitles["SubtitleFileUris"] = [
f"{transcript_file_prefix}.{format}"
for format in self.subtitles["Formats"]
]
class FakeVocabulary(BaseObject, ManagedState): class FakeVocabulary(BaseObject, ManagedState):
@@ -504,6 +522,7 @@ class TranscribeBackend(BaseBackend):
identify_language: Optional[bool], identify_language: Optional[bool],
identify_multiple_languages: Optional[bool], identify_multiple_languages: Optional[bool],
language_options: Optional[List[str]], language_options: Optional[List[str]],
subtitles: Optional[Dict[str, Any]],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if transcription_job_name in self.transcriptions: if transcription_job_name in self.transcriptions:
raise ConflictException( raise ConflictException(
@@ -535,6 +554,7 @@ class TranscribeBackend(BaseBackend):
identify_language=identify_language, identify_language=identify_language,
identify_multiple_languages=identify_multiple_languages, identify_multiple_languages=identify_multiple_languages,
language_options=language_options, language_options=language_options,
subtitles=subtitles,
) )
self.transcriptions[transcription_job_name] = transcription_job_object self.transcriptions[transcription_job_name] = transcription_job_object

View File

@@ -32,6 +32,7 @@ class TranscribeResponse(BaseResponse):
identify_language=self._get_param("IdentifyLanguage"), identify_language=self._get_param("IdentifyLanguage"),
identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"), identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"),
language_options=self._get_param("LanguageOptions"), language_options=self._get_param("LanguageOptions"),
subtitles=self._get_param("Subtitles"),
) )
return json.dumps(response) return json.dumps(response)

View File

@@ -209,6 +209,10 @@ def test_run_transcription_job_all_params():
"MaxAlternatives": 6, "MaxAlternatives": 6,
"VocabularyName": vocabulary_name, "VocabularyName": vocabulary_name,
}, },
"Subtitles": {
"Formats": ["srt", "vtt"],
"OutputStartIndex": 1,
},
# Missing `ContentRedaction`, `JobExecutionSettings`, # Missing `ContentRedaction`, `JobExecutionSettings`,
# `VocabularyFilterName`, `LanguageModel` # `VocabularyFilterName`, `LanguageModel`
} }
@@ -269,6 +273,18 @@ def test_run_transcription_job_all_params():
f"/{args['TranscriptionJobName']}.json" f"/{args['TranscriptionJobName']}.json"
), ),
} }
assert transcription_job["Subtitles"] == {
"Formats": args["Subtitles"]["Formats"],
"OutputStartIndex": 1,
"SubtitleFileUris": [
(
f"https://s3.{region_name}.amazonaws.com"
f"/{args['OutputBucketName']}"
f"/{args['TranscriptionJobName']}.{format}"
)
for format in args["Subtitles"]["Formats"]
],
}
@mock_transcribe @mock_transcribe
@@ -329,6 +345,11 @@ def test_run_transcription_job_minimal_params():
assert ( assert (
f"https://s3.{region_name}.amazonaws.com/aws-transcribe-{region_name}-prod/" f"https://s3.{region_name}.amazonaws.com/aws-transcribe-{region_name}-prod/"
) in transcription_job["Transcript"]["TranscriptFileUri"] ) in transcription_job["Transcript"]["TranscriptFileUri"]
assert transcription_job["Subtitles"] == {
"Formats": [],
"OutputStartIndex": 0,
"SubtitleFileUris": [],
}
# Delete # Delete
client.delete_transcription_job(TranscriptionJobName=job_name) client.delete_transcription_job(TranscriptionJobName=job_name)
@@ -348,7 +369,8 @@ def test_run_transcription_job_s3output_params():
"LanguageCode": "en-US", "LanguageCode": "en-US",
"Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"}, "Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"},
"OutputBucketName": "my-output-bucket", "OutputBucketName": "my-output-bucket",
"OutputKey": "bucket-key", "OutputKey": "bucket.json.key.json",
"Subtitles": {"Formats": ["vtt", "srt"]},
} }
resp = client.start_transcription_job(**args) resp = client.start_transcription_job(**args)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
@@ -379,8 +401,16 @@ def test_run_transcription_job_s3output_params():
assert "Transcript" in transcription_job assert "Transcript" in transcription_job
# Check aws hosted bucket # Check aws hosted bucket
assert ( assert (
"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket-key/MyJob.json" "https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.json"
) in transcription_job["Transcript"]["TranscriptFileUri"] ) in transcription_job["Transcript"]["TranscriptFileUri"]
assert transcription_job["Subtitles"] == {
"Formats": args["Subtitles"]["Formats"],
"SubtitleFileUris": [
f"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.{format}"
for format in args["Subtitles"]["Formats"]
],
}
# A new job without an "OutputKey" # A new job without an "OutputKey"
job_name = "MyJob2" job_name = "MyJob2"
args = { args = {