Transcribe: add IdentifyMultipleLanguages and return LanguageCodes (#5918)

This commit is contained in:
Ben van den Berg 2023-02-11 04:50:30 -07:00 committed by GitHub
parent bb1b5f511a
commit 11ccbd0b9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 15 deletions

View File

@ -45,6 +45,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
job_execution_settings,
content_redaction,
identify_language,
identify_multiple_languages,
language_options,
):
ManagedState.__init__(
@ -60,6 +61,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self._region_name = region_name
self.transcription_job_name = transcription_job_name
self.language_code = language_code
self.language_codes = None
self.media_sample_rate_hertz = media_sample_rate_hertz
self.media_format = media_format
self.media = media
@ -82,6 +84,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"RedactionOutput": None,
}
self.identify_language = identify_language
self.identify_multiple_languages = identify_multiple_languages
self.language_options = language_options
self.identified_language_score = (None,)
self._output_bucket_name = output_bucket_name
@ -97,12 +100,14 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"TranscriptionJobName",
"TranscriptionJobStatus",
"LanguageCode",
"LanguageCodes",
"MediaFormat",
"Media",
"Settings",
"StartTime",
"CreationTime",
"IdentifyLanguage",
"IdentifyMultipleLanguages",
"LanguageOptions",
"JobExecutionSettings",
],
@ -110,6 +115,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"TranscriptionJobName",
"TranscriptionJobStatus",
"LanguageCode",
"LanguageCodes",
"MediaSampleRateHertz",
"MediaFormat",
"Media",
@ -119,6 +125,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"CreationTime",
"CompletionTime",
"IdentifyLanguage",
"IdentifyMultipleLanguages",
"LanguageOptions",
"IdentifiedLanguageScore",
],
@ -128,9 +135,11 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"StartTime",
"CompletionTime",
"LanguageCode",
"LanguageCodes",
"TranscriptionJobStatus",
"FailureReason",
"IdentifyLanguage",
"IdentifyMultipleLanguages",
"IdentifiedLanguageScore",
"OutputLocationType",
],
@ -172,12 +181,35 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
)
if self.identify_language:
self.identified_language_score = 0.999645948
# Simply identify first language passed in lanugage_options
# If non is set default to "en-US"
# Simply identify first language passed in language_options
# If none is set, default to "en-US"
if self.language_options is not None and len(self.language_options) > 0:
self.language_code = self.language_options[0]
else:
self.language_code = "en-US"
if self.identify_multiple_languages:
self.identified_language_score = 0.999645948
# Identify first two languages passed in language_options
# If none is set, default to "en-US"
self.language_codes = []
if self.language_options is None or len(self.language_options) == 0:
self.language_codes.append(
{"LanguageCode": "en-US", "DurationInSeconds": 123.0}
)
else:
self.language_codes.append(
{
"LanguageCode": self.language_options[0],
"DurationInSeconds": 123.0,
}
)
if len(self.language_options) > 1:
self.language_codes.append(
{
"LanguageCode": self.language_options[1],
"DurationInSeconds": 321.0,
}
)
elif new_status == "COMPLETED":
self.completion_time = (datetime.now() + timedelta(seconds=10)).strftime(
"%Y-%m-%d %H:%M:%S"
@ -482,6 +514,7 @@ class TranscribeBackend(BaseBackend):
job_execution_settings=kwargs.get("job_execution_settings"),
content_redaction=kwargs.get("content_redaction"),
identify_language=kwargs.get("identify_language"),
identify_multiple_languages=kwargs.get("identify_multiple_languages"),
language_options=kwargs.get("language_options"),
)
self.transcriptions[name] = transcription_job_object

View File

@ -37,6 +37,7 @@ class TranscribeResponse(BaseResponse):
job_execution_settings=self._get_param("JobExecutionSettings"),
content_redaction=self._get_param("ContentRedaction"),
identify_language=self._get_param("IdentifyLanguage"),
identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"),
language_options=self._get_param("LanguageOptions"),
)
return json.dumps(response)

View File

@ -397,11 +397,12 @@ def test_run_transcription_job_s3output_params():
@mock_transcribe
def test_run_transcription_job_identify_language_params():
def test_run_transcription_job_identify_languages_params():
region_name = "us-east-1"
client = boto3.client("transcribe", region_name=region_name)
# IdentifyLanguage
job_name = "MyJob"
args = {
"TranscriptionJobName": job_name,
@ -409,18 +410,64 @@ def test_run_transcription_job_identify_language_params():
"IdentifyLanguage": True,
"LanguageOptions": ["en-US", "en-GB", "es-ES", "de-DE"],
}
resp = client.start_transcription_job(**args)
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
transcription_job = resp["TranscriptionJob"]
transcription_job.should.contain("IdentifyLanguage")
transcription_job.should.contain("LanguageOptions")
client.get_transcription_job(TranscriptionJobName=job_name)
resp = client.get_transcription_job(TranscriptionJobName=job_name)
transcription_job = resp["TranscriptionJob"]
transcription_job.should.contain("LanguageCode")
transcription_job.should.contain("IdentifiedLanguageScore")
transcription_job["LanguageCode"].should.equal("en-US")
transcription_job["IdentifiedLanguageScore"].should.equal(0.999645948)
resp_data = [
client.start_transcription_job(**args), # CREATED
client.get_transcription_job(TranscriptionJobName=job_name), # QUEUED
client.get_transcription_job(TranscriptionJobName=job_name), # IN_PROGRESS
client.list_transcription_jobs(), # IN_PROGRESS
]
for resp in resp_data:
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
if "TranscriptionJob" in resp:
transcription_job = resp["TranscriptionJob"]
elif "TranscriptionJobSummaries" in resp:
transcription_job = resp["TranscriptionJobSummaries"][0]
transcription_job.should.contain("IdentifyLanguage")
transcription_job.should_not.contain("LanguageCodes")
transcription_job.should_not.contain("IdentifyMultipleLanguages")
if "TranscriptionJobStatus" in transcription_job and (
transcription_job["TranscriptionJobStatus"] == "IN_PROGRESS"
or transcription_job["TranscriptionJobStatus"] == "COMPLETED"
):
transcription_job["LanguageCode"].should.equal("en-US")
transcription_job["IdentifiedLanguageScore"].should.equal(0.999645948)
# IdentifyMultipleLanguages
job_name = "MyJob2"
args = {
"TranscriptionJobName": job_name,
"Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"},
"IdentifyMultipleLanguages": True,
"LanguageOptions": ["en-US", "en-GB", "es-ES", "de-DE"],
}
resp_data = [
client.start_transcription_job(**args), # CREATED
client.get_transcription_job(TranscriptionJobName=job_name), # QUEUED
client.get_transcription_job(TranscriptionJobName=job_name), # IN_PROGRESS
client.list_transcription_jobs(), # IN_PROGRESS
]
for resp in resp_data:
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
if "TranscriptionJob" in resp:
transcription_job = resp["TranscriptionJob"]
elif "TranscriptionJobSummaries" in resp:
transcription_job = resp["TranscriptionJobSummaries"][1]
transcription_job.should.contain("IdentifyMultipleLanguages")
transcription_job.should_not.contain("LanguageCode")
transcription_job.should_not.contain("IdentifyLanguage")
if "TranscriptionJobStatus" in transcription_job and (
transcription_job["TranscriptionJobStatus"] == "IN_PROGRESS"
or transcription_job["TranscriptionJobStatus"] == "COMPLETED"
):
transcription_job["LanguageCodes"][0]["LanguageCode"].should.equal("en-US")
transcription_job["LanguageCodes"][0]["DurationInSeconds"].should.equal(
123.0
)
transcription_job["LanguageCodes"][1]["LanguageCode"].should.equal("en-GB")
transcription_job["LanguageCodes"][1]["DurationInSeconds"].should.equal(
321.0
)
transcription_job["IdentifiedLanguageScore"].should.equal(0.999645948)
@mock_transcribe