Add subtitle support for transcribe (#7028)
				
					
				
			This commit is contained in:
		
							parent
							
								
									63e869d717
								
							
						
					
					
						commit
						86f1d53f54
					
				@ -1,3 +1,4 @@
 | 
			
		||||
import re
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from typing import Any, Dict, List, Optional
 | 
			
		||||
from moto.core import BaseBackend, BackendDict, BaseModel
 | 
			
		||||
@ -48,6 +49,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
 | 
			
		||||
        identify_language: Optional[bool],
 | 
			
		||||
        identify_multiple_languages: Optional[bool],
 | 
			
		||||
        language_options: Optional[List[str]],
 | 
			
		||||
        subtitles: Optional[Dict[str, Any]],
 | 
			
		||||
    ):
 | 
			
		||||
        ManagedState.__init__(
 | 
			
		||||
            self,
 | 
			
		||||
@ -95,6 +97,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
 | 
			
		||||
        self.output_location_type = (
 | 
			
		||||
            "CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
 | 
			
		||||
        )
 | 
			
		||||
        self.subtitles = subtitles or {"Formats": [], "OutputStartIndex": 0}
 | 
			
		||||
 | 
			
		||||
    def response_object(self, response_type: str) -> Dict[str, Any]:  # type: ignore
 | 
			
		||||
        response_field_dict = {
 | 
			
		||||
@ -112,6 +115,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
 | 
			
		||||
                "IdentifyMultipleLanguages",
 | 
			
		||||
                "LanguageOptions",
 | 
			
		||||
                "JobExecutionSettings",
 | 
			
		||||
                "Subtitles",
 | 
			
		||||
            ],
 | 
			
		||||
            "GET": [
 | 
			
		||||
                "TranscriptionJobName",
 | 
			
		||||
@ -130,6 +134,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
 | 
			
		||||
                "IdentifyMultipleLanguages",
 | 
			
		||||
                "LanguageOptions",
 | 
			
		||||
                "IdentifiedLanguageScore",
 | 
			
		||||
                "Subtitles",
 | 
			
		||||
            ],
 | 
			
		||||
            "LIST": [
 | 
			
		||||
                "TranscriptionJobName",
 | 
			
		||||
@ -217,15 +222,28 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
 | 
			
		||||
                "%Y-%m-%d %H:%M:%S"
 | 
			
		||||
            )
 | 
			
		||||
            if self._output_bucket_name:
 | 
			
		||||
                transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/{self._output_bucket_name}/"
 | 
			
		||||
                if self.output_key is not None:
 | 
			
		||||
                    transcript_file_uri += f"{self.output_key}/"
 | 
			
		||||
                transcript_file_uri += f"{self.transcription_job_name}.json"
 | 
			
		||||
                remove_json_extension = re.compile("\\.json$")
 | 
			
		||||
                transcript_file_prefix = (
 | 
			
		||||
                    f"https://s3.{self._region_name}.amazonaws.com/"
 | 
			
		||||
                    f"{self._output_bucket_name}/"
 | 
			
		||||
                    f"{remove_json_extension.sub('', self.output_key or self.transcription_job_name)}"
 | 
			
		||||
                )
 | 
			
		||||
                self.output_location_type = "CUSTOMER_BUCKET"
 | 
			
		||||
            else:
 | 
			
		||||
                transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/aws-transcribe-{self._region_name}-prod/{self._account_id}/{self.transcription_job_name}/{mock_random.uuid4()}/asrOutput.json"
 | 
			
		||||
                transcript_file_prefix = (
 | 
			
		||||
                    f"https://s3.{self._region_name}.amazonaws.com/"
 | 
			
		||||
                    f"aws-transcribe-{self._region_name}-prod/"
 | 
			
		||||
                    f"{self._account_id}/"
 | 
			
		||||
                    f"{self.transcription_job_name}/"
 | 
			
		||||
                    f"{mock_random.uuid4()}/"
 | 
			
		||||
                    "asrOutput"
 | 
			
		||||
                )
 | 
			
		||||
                self.output_location_type = "SERVICE_BUCKET"
 | 
			
		||||
            self.transcript = {"TranscriptFileUri": transcript_file_uri}
 | 
			
		||||
            self.transcript = {"TranscriptFileUri": f"{transcript_file_prefix}.json"}
 | 
			
		||||
            self.subtitles["SubtitleFileUris"] = [
 | 
			
		||||
                f"{transcript_file_prefix}.{format}"
 | 
			
		||||
                for format in self.subtitles["Formats"]
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeVocabulary(BaseObject, ManagedState):
 | 
			
		||||
@ -504,6 +522,7 @@ class TranscribeBackend(BaseBackend):
 | 
			
		||||
        identify_language: Optional[bool],
 | 
			
		||||
        identify_multiple_languages: Optional[bool],
 | 
			
		||||
        language_options: Optional[List[str]],
 | 
			
		||||
        subtitles: Optional[Dict[str, Any]],
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        if transcription_job_name in self.transcriptions:
 | 
			
		||||
            raise ConflictException(
 | 
			
		||||
@ -535,6 +554,7 @@ class TranscribeBackend(BaseBackend):
 | 
			
		||||
            identify_language=identify_language,
 | 
			
		||||
            identify_multiple_languages=identify_multiple_languages,
 | 
			
		||||
            language_options=language_options,
 | 
			
		||||
            subtitles=subtitles,
 | 
			
		||||
        )
 | 
			
		||||
        self.transcriptions[transcription_job_name] = transcription_job_object
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -32,6 +32,7 @@ class TranscribeResponse(BaseResponse):
 | 
			
		||||
            identify_language=self._get_param("IdentifyLanguage"),
 | 
			
		||||
            identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"),
 | 
			
		||||
            language_options=self._get_param("LanguageOptions"),
 | 
			
		||||
            subtitles=self._get_param("Subtitles"),
 | 
			
		||||
        )
 | 
			
		||||
        return json.dumps(response)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -209,6 +209,10 @@ def test_run_transcription_job_all_params():
 | 
			
		||||
            "MaxAlternatives": 6,
 | 
			
		||||
            "VocabularyName": vocabulary_name,
 | 
			
		||||
        },
 | 
			
		||||
        "Subtitles": {
 | 
			
		||||
            "Formats": ["srt", "vtt"],
 | 
			
		||||
            "OutputStartIndex": 1,
 | 
			
		||||
        },
 | 
			
		||||
        # Missing `ContentRedaction`, `JobExecutionSettings`,
 | 
			
		||||
        # `VocabularyFilterName`, `LanguageModel`
 | 
			
		||||
    }
 | 
			
		||||
@ -269,6 +273,18 @@ def test_run_transcription_job_all_params():
 | 
			
		||||
            f"/{args['TranscriptionJobName']}.json"
 | 
			
		||||
        ),
 | 
			
		||||
    }
 | 
			
		||||
    assert transcription_job["Subtitles"] == {
 | 
			
		||||
        "Formats": args["Subtitles"]["Formats"],
 | 
			
		||||
        "OutputStartIndex": 1,
 | 
			
		||||
        "SubtitleFileUris": [
 | 
			
		||||
            (
 | 
			
		||||
                f"https://s3.{region_name}.amazonaws.com"
 | 
			
		||||
                f"/{args['OutputBucketName']}"
 | 
			
		||||
                f"/{args['TranscriptionJobName']}.{format}"
 | 
			
		||||
            )
 | 
			
		||||
            for format in args["Subtitles"]["Formats"]
 | 
			
		||||
        ],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_transcribe
 | 
			
		||||
@ -329,6 +345,11 @@ def test_run_transcription_job_minimal_params():
 | 
			
		||||
    assert (
 | 
			
		||||
        f"https://s3.{region_name}.amazonaws.com/aws-transcribe-{region_name}-prod/"
 | 
			
		||||
    ) in transcription_job["Transcript"]["TranscriptFileUri"]
 | 
			
		||||
    assert transcription_job["Subtitles"] == {
 | 
			
		||||
        "Formats": [],
 | 
			
		||||
        "OutputStartIndex": 0,
 | 
			
		||||
        "SubtitleFileUris": [],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # Delete
 | 
			
		||||
    client.delete_transcription_job(TranscriptionJobName=job_name)
 | 
			
		||||
@ -348,7 +369,8 @@ def test_run_transcription_job_s3output_params():
 | 
			
		||||
        "LanguageCode": "en-US",
 | 
			
		||||
        "Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"},
 | 
			
		||||
        "OutputBucketName": "my-output-bucket",
 | 
			
		||||
        "OutputKey": "bucket-key",
 | 
			
		||||
        "OutputKey": "bucket.json.key.json",
 | 
			
		||||
        "Subtitles": {"Formats": ["vtt", "srt"]},
 | 
			
		||||
    }
 | 
			
		||||
    resp = client.start_transcription_job(**args)
 | 
			
		||||
    assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
 | 
			
		||||
@ -379,8 +401,16 @@ def test_run_transcription_job_s3output_params():
 | 
			
		||||
    assert "Transcript" in transcription_job
 | 
			
		||||
    # Check aws hosted bucket
 | 
			
		||||
    assert (
 | 
			
		||||
        "https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket-key/MyJob.json"
 | 
			
		||||
        "https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.json"
 | 
			
		||||
    ) in transcription_job["Transcript"]["TranscriptFileUri"]
 | 
			
		||||
    assert transcription_job["Subtitles"] == {
 | 
			
		||||
        "Formats": args["Subtitles"]["Formats"],
 | 
			
		||||
        "SubtitleFileUris": [
 | 
			
		||||
            f"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.{format}"
 | 
			
		||||
            for format in args["Subtitles"]["Formats"]
 | 
			
		||||
        ],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # A new job without an "OutputKey"
 | 
			
		||||
    job_name = "MyJob2"
 | 
			
		||||
    args = {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user