Add subtitle support for transcribe
(#7028)
This commit is contained in:
parent
63e869d717
commit
86f1d53f54
@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from moto.core import BaseBackend, BackendDict, BaseModel
|
from moto.core import BaseBackend, BackendDict, BaseModel
|
||||||
@ -48,6 +49,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
|||||||
identify_language: Optional[bool],
|
identify_language: Optional[bool],
|
||||||
identify_multiple_languages: Optional[bool],
|
identify_multiple_languages: Optional[bool],
|
||||||
language_options: Optional[List[str]],
|
language_options: Optional[List[str]],
|
||||||
|
subtitles: Optional[Dict[str, Any]],
|
||||||
):
|
):
|
||||||
ManagedState.__init__(
|
ManagedState.__init__(
|
||||||
self,
|
self,
|
||||||
@ -95,6 +97,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
|||||||
self.output_location_type = (
|
self.output_location_type = (
|
||||||
"CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
|
"CUSTOMER_BUCKET" if self._output_bucket_name else "SERVICE_BUCKET"
|
||||||
)
|
)
|
||||||
|
self.subtitles = subtitles or {"Formats": [], "OutputStartIndex": 0}
|
||||||
|
|
||||||
def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
|
def response_object(self, response_type: str) -> Dict[str, Any]: # type: ignore
|
||||||
response_field_dict = {
|
response_field_dict = {
|
||||||
@ -112,6 +115,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
|||||||
"IdentifyMultipleLanguages",
|
"IdentifyMultipleLanguages",
|
||||||
"LanguageOptions",
|
"LanguageOptions",
|
||||||
"JobExecutionSettings",
|
"JobExecutionSettings",
|
||||||
|
"Subtitles",
|
||||||
],
|
],
|
||||||
"GET": [
|
"GET": [
|
||||||
"TranscriptionJobName",
|
"TranscriptionJobName",
|
||||||
@ -130,6 +134,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
|||||||
"IdentifyMultipleLanguages",
|
"IdentifyMultipleLanguages",
|
||||||
"LanguageOptions",
|
"LanguageOptions",
|
||||||
"IdentifiedLanguageScore",
|
"IdentifiedLanguageScore",
|
||||||
|
"Subtitles",
|
||||||
],
|
],
|
||||||
"LIST": [
|
"LIST": [
|
||||||
"TranscriptionJobName",
|
"TranscriptionJobName",
|
||||||
@ -217,15 +222,28 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
|
|||||||
"%Y-%m-%d %H:%M:%S"
|
"%Y-%m-%d %H:%M:%S"
|
||||||
)
|
)
|
||||||
if self._output_bucket_name:
|
if self._output_bucket_name:
|
||||||
transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/{self._output_bucket_name}/"
|
remove_json_extension = re.compile("\\.json$")
|
||||||
if self.output_key is not None:
|
transcript_file_prefix = (
|
||||||
transcript_file_uri += f"{self.output_key}/"
|
f"https://s3.{self._region_name}.amazonaws.com/"
|
||||||
transcript_file_uri += f"{self.transcription_job_name}.json"
|
f"{self._output_bucket_name}/"
|
||||||
|
f"{remove_json_extension.sub('', self.output_key or self.transcription_job_name)}"
|
||||||
|
)
|
||||||
self.output_location_type = "CUSTOMER_BUCKET"
|
self.output_location_type = "CUSTOMER_BUCKET"
|
||||||
else:
|
else:
|
||||||
transcript_file_uri = f"https://s3.{self._region_name}.amazonaws.com/aws-transcribe-{self._region_name}-prod/{self._account_id}/{self.transcription_job_name}/{mock_random.uuid4()}/asrOutput.json"
|
transcript_file_prefix = (
|
||||||
|
f"https://s3.{self._region_name}.amazonaws.com/"
|
||||||
|
f"aws-transcribe-{self._region_name}-prod/"
|
||||||
|
f"{self._account_id}/"
|
||||||
|
f"{self.transcription_job_name}/"
|
||||||
|
f"{mock_random.uuid4()}/"
|
||||||
|
"asrOutput"
|
||||||
|
)
|
||||||
self.output_location_type = "SERVICE_BUCKET"
|
self.output_location_type = "SERVICE_BUCKET"
|
||||||
self.transcript = {"TranscriptFileUri": transcript_file_uri}
|
self.transcript = {"TranscriptFileUri": f"{transcript_file_prefix}.json"}
|
||||||
|
self.subtitles["SubtitleFileUris"] = [
|
||||||
|
f"{transcript_file_prefix}.{format}"
|
||||||
|
for format in self.subtitles["Formats"]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class FakeVocabulary(BaseObject, ManagedState):
|
class FakeVocabulary(BaseObject, ManagedState):
|
||||||
@ -504,6 +522,7 @@ class TranscribeBackend(BaseBackend):
|
|||||||
identify_language: Optional[bool],
|
identify_language: Optional[bool],
|
||||||
identify_multiple_languages: Optional[bool],
|
identify_multiple_languages: Optional[bool],
|
||||||
language_options: Optional[List[str]],
|
language_options: Optional[List[str]],
|
||||||
|
subtitles: Optional[Dict[str, Any]],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if transcription_job_name in self.transcriptions:
|
if transcription_job_name in self.transcriptions:
|
||||||
raise ConflictException(
|
raise ConflictException(
|
||||||
@ -535,6 +554,7 @@ class TranscribeBackend(BaseBackend):
|
|||||||
identify_language=identify_language,
|
identify_language=identify_language,
|
||||||
identify_multiple_languages=identify_multiple_languages,
|
identify_multiple_languages=identify_multiple_languages,
|
||||||
language_options=language_options,
|
language_options=language_options,
|
||||||
|
subtitles=subtitles,
|
||||||
)
|
)
|
||||||
self.transcriptions[transcription_job_name] = transcription_job_object
|
self.transcriptions[transcription_job_name] = transcription_job_object
|
||||||
|
|
||||||
|
@ -32,6 +32,7 @@ class TranscribeResponse(BaseResponse):
|
|||||||
identify_language=self._get_param("IdentifyLanguage"),
|
identify_language=self._get_param("IdentifyLanguage"),
|
||||||
identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"),
|
identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"),
|
||||||
language_options=self._get_param("LanguageOptions"),
|
language_options=self._get_param("LanguageOptions"),
|
||||||
|
subtitles=self._get_param("Subtitles"),
|
||||||
)
|
)
|
||||||
return json.dumps(response)
|
return json.dumps(response)
|
||||||
|
|
||||||
|
@ -209,6 +209,10 @@ def test_run_transcription_job_all_params():
|
|||||||
"MaxAlternatives": 6,
|
"MaxAlternatives": 6,
|
||||||
"VocabularyName": vocabulary_name,
|
"VocabularyName": vocabulary_name,
|
||||||
},
|
},
|
||||||
|
"Subtitles": {
|
||||||
|
"Formats": ["srt", "vtt"],
|
||||||
|
"OutputStartIndex": 1,
|
||||||
|
},
|
||||||
# Missing `ContentRedaction`, `JobExecutionSettings`,
|
# Missing `ContentRedaction`, `JobExecutionSettings`,
|
||||||
# `VocabularyFilterName`, `LanguageModel`
|
# `VocabularyFilterName`, `LanguageModel`
|
||||||
}
|
}
|
||||||
@ -269,6 +273,18 @@ def test_run_transcription_job_all_params():
|
|||||||
f"/{args['TranscriptionJobName']}.json"
|
f"/{args['TranscriptionJobName']}.json"
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
assert transcription_job["Subtitles"] == {
|
||||||
|
"Formats": args["Subtitles"]["Formats"],
|
||||||
|
"OutputStartIndex": 1,
|
||||||
|
"SubtitleFileUris": [
|
||||||
|
(
|
||||||
|
f"https://s3.{region_name}.amazonaws.com"
|
||||||
|
f"/{args['OutputBucketName']}"
|
||||||
|
f"/{args['TranscriptionJobName']}.{format}"
|
||||||
|
)
|
||||||
|
for format in args["Subtitles"]["Formats"]
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@mock_transcribe
|
@mock_transcribe
|
||||||
@ -329,6 +345,11 @@ def test_run_transcription_job_minimal_params():
|
|||||||
assert (
|
assert (
|
||||||
f"https://s3.{region_name}.amazonaws.com/aws-transcribe-{region_name}-prod/"
|
f"https://s3.{region_name}.amazonaws.com/aws-transcribe-{region_name}-prod/"
|
||||||
) in transcription_job["Transcript"]["TranscriptFileUri"]
|
) in transcription_job["Transcript"]["TranscriptFileUri"]
|
||||||
|
assert transcription_job["Subtitles"] == {
|
||||||
|
"Formats": [],
|
||||||
|
"OutputStartIndex": 0,
|
||||||
|
"SubtitleFileUris": [],
|
||||||
|
}
|
||||||
|
|
||||||
# Delete
|
# Delete
|
||||||
client.delete_transcription_job(TranscriptionJobName=job_name)
|
client.delete_transcription_job(TranscriptionJobName=job_name)
|
||||||
@ -348,7 +369,8 @@ def test_run_transcription_job_s3output_params():
|
|||||||
"LanguageCode": "en-US",
|
"LanguageCode": "en-US",
|
||||||
"Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"},
|
"Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"},
|
||||||
"OutputBucketName": "my-output-bucket",
|
"OutputBucketName": "my-output-bucket",
|
||||||
"OutputKey": "bucket-key",
|
"OutputKey": "bucket.json.key.json",
|
||||||
|
"Subtitles": {"Formats": ["vtt", "srt"]},
|
||||||
}
|
}
|
||||||
resp = client.start_transcription_job(**args)
|
resp = client.start_transcription_job(**args)
|
||||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
@ -379,8 +401,16 @@ def test_run_transcription_job_s3output_params():
|
|||||||
assert "Transcript" in transcription_job
|
assert "Transcript" in transcription_job
|
||||||
# Check aws hosted bucket
|
# Check aws hosted bucket
|
||||||
assert (
|
assert (
|
||||||
"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket-key/MyJob.json"
|
"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.json"
|
||||||
) in transcription_job["Transcript"]["TranscriptFileUri"]
|
) in transcription_job["Transcript"]["TranscriptFileUri"]
|
||||||
|
assert transcription_job["Subtitles"] == {
|
||||||
|
"Formats": args["Subtitles"]["Formats"],
|
||||||
|
"SubtitleFileUris": [
|
||||||
|
f"https://s3.us-east-1.amazonaws.com/my-output-bucket/bucket.json.key.{format}"
|
||||||
|
for format in args["Subtitles"]["Formats"]
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
# A new job without an "OutputKey"
|
# A new job without an "OutputKey"
|
||||||
job_name = "MyJob2"
|
job_name = "MyJob2"
|
||||||
args = {
|
args = {
|
||||||
|
Loading…
Reference in New Issue
Block a user