diff --git a/moto/transcribe/models.py b/moto/transcribe/models.py index 63a9d648f..6442c0eb9 100644 --- a/moto/transcribe/models.py +++ b/moto/transcribe/models.py @@ -1,3 +1,4 @@ +import re from datetime import datetime, timedelta from typing import Any, Dict, List, Optional from moto.core import BaseBackend, BackendDict, BaseModel @@ -48,6 +49,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): identify_language: Optional[bool], identify_multiple_languages: Optional[bool], language_options: Optional[List[str]], + subtitles: Optional[Dict[str, Any]], ): ManagedState.__init__( self, @@ -95,6 +97,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): self.output_location_type = ( "CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET" ) + self.subtitles = subtitles or {"Formats": [], "OutputStartIndex": 0} def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore response_field_dict = { @@ -112,6 +115,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): "IdentifyMultipleLanguages", "LanguageOptions", "JobExecutionSettings", + "Subtitles", ], "GET": [ "TranscriptionJobName", @@ -130,6 +134,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): "IdentifyMultipleLanguages", "LanguageOptions", "IdentifiedLanguageScore", + "Subtitles", ], "LIST": [ "TranscriptionJobName", @@ -217,15 +222,28 @@ class FakeTranscriptionJob(BaseObject, ManagedState): "%Y-%m-%d %H:%M:%S" ) if self._output_bucket_name: - transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/{self._output_bucket_name}/" - if self.output_key is not None: - transcript_file_uri += f"{self.output_key}/" - transcript_file_uri += f"{self.transcription_job_name}.json" + remove_json_extension = re.compile("\\.json$") + transcript_file_prefix = ( + f"https://s3.{self._region_name}.amazonaws.com/" + f"{self._output_bucket_name}/" + f"{remove_json_extension.sub('', self.output_key or self.transcription_job_name)}" + ) self.output_location_type = "CUSTOMER_BUCKET" else: - transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/aws-transcribe-{self._region_name}-prod/{self._account_id}/{self.transcription_job_name}/{mock_random.uuid4()}/asrOutput.json" + transcript_file_prefix = ( + f"https://s3.{self._region_name}.amazonaws.com/" + f"aws-transcribe-{self._region_name}-prod/" + f"{self._account_id}/" + f"{self.transcription_job_name}/" + f"{mock_random.uuid4()}/" + "asrOutput" + ) self.output_location_type = "SERVICE_BUCKET" - self.transcript = {"TranscriptFileUri": transcript_file_uri} + self.transcript = {"TranscriptFileUri": f"{transcript_file_prefix}.json"} + self.subtitles["SubtitleFileUris"] = [ + f"{transcript_file_prefix}.{format}" + for format in self.subtitles["Formats"] + ] class FakeVocabulary(BaseObject, ManagedState): @@ -504,6 +522,7 @@ class TranscribeBackend(BaseBackend): identify_language: Optional[bool], identify_multiple_languages: Optional[bool], language_options: Optional[List[str]], + subtitles: Optional[Dict[str, Any]], ) -> Dict[str, Any]: if transcription_job_name in self.transcriptions: raise ConflictException( @@ -535,6 +554,7 @@ class TranscribeBackend(BaseBackend): identify_language=identify_language, identify_multiple_languages=identify_multiple_languages, language_options=language_options, + subtitles=subtitles, ) self.transcriptions[transcription_job_name] = transcription_job_object diff --git a/moto/transcribe/responses.py b/moto/transcribe/responses.py index a74771c3a..593837c75 100644 --- a/moto/transcribe/responses.py +++ b/moto/transcribe/responses.py @@ -32,6 +32,7 @@ class TranscribeResponse(BaseResponse): identify_language=self._get_param("IdentifyLanguage"), identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"), language_options=self._get_param("LanguageOptions"), + subtitles=self._get_param("Subtitles"), ) return json.dumps(response) diff --git a/tests/test_transcribe/test_transcribe_boto3.py b/tests/test_transcribe/test_transcribe_boto3.py index a39d8156d..6d62288a7 100644 --- a/tests/test_transcribe/test_transcribe_boto3.py +++ b/tests/test_transcribe/test_transcribe_boto3.py @@ -209,6 +209,10 @@ def test_run_transcription_job_all_params(): "MaxAlternatives": 6, "VocabularyName": vocabulary_name, }, + "Subtitles": { + "Formats": ["srt", "vtt"], + "OutputStartIndex": 1, + }, # Missing `ContentRedaction`, `JobExecutionSettings`, # `VocabularyFilterName`, `LanguageModel` } @@ -269,6 +273,18 @@ def test_run_transcription_job_all_params(): f"/{args['TranscriptionJobName']}.json" ), } + assert transcription_job["Subtitles"] == { + "Formats": args["Subtitles"]["Formats"], + "OutputStartIndex": 1, + "SubtitleFileUris": [ + ( + f"https://s3.{region_name}.amazonaws.com" + f"/{args['OutputBucketName']}" + f"/{args['TranscriptionJobName']}.{format}" + ) + for format in args["Subtitles"]["Formats"] + ], + } @mock_transcribe @@ -329,6 +345,11 @@ def test_run_transcription_job_minimal_params(): assert ( f"https://s3.{region_name}.amazonaws.com/aws-transcribe-{region_name}-prod/" ) in transcription_job["Transcript"]["TranscriptFileUri"] + assert transcription_job["Subtitles"] == { + "Formats": [], + "OutputStartIndex": 0, + "SubtitleFileUris": [], + } # Delete client.delete_transcription_job(TranscriptionJobName=job_name) @@ -348,7 +369,8 @@ def test_run_transcription_job_s3output_params(): "LanguageCode": "en-US", "Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"}, "OutputBucketName": "my-output-bucket", - "OutputKey": "bucket-key", + "OutputKey": "bucket.json.key.json", + "Subtitles": {"Formats": ["vtt", "srt"]}, } resp = client.start_transcription_job(**args) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 @@ -379,8 +401,16 @@ def test_run_transcription_job_s3output_params(): assert "Transcript" in transcription_job # Check aws hosted bucket assert ( - "https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket-key/MyJob.json" + "https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.json" ) in transcription_job["Transcript"]["TranscriptFileUri"] + assert transcription_job["Subtitles"] == { + "Formats": args["Subtitles"]["Formats"], + "SubtitleFileUris": [ + f"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.{format}" + for format in args["Subtitles"]["Formats"] + ], + } + # A new job without an "OutputKey" job_name = "MyJob2" args = {