Transcribe: add IdentifyMultipleLanguages and return LanguageCodes (#5918)

This commit is contained in:
Ben van den Berg 2023-02-11 04:50:30 -07:00 committed by GitHub
parent bb1b5f511a
commit 11ccbd0b9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 15 deletions

View File

@ -45,6 +45,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
job_execution_settings, job_execution_settings,
content_redaction, content_redaction,
identify_language, identify_language,
identify_multiple_languages,
language_options, language_options,
): ):
ManagedState.__init__( ManagedState.__init__(
@ -60,6 +61,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
self._region_name = region_name self._region_name = region_name
self.transcription_job_name = transcription_job_name self.transcription_job_name = transcription_job_name
self.language_code = language_code self.language_code = language_code
self.language_codes = None
self.media_sample_rate_hertz = media_sample_rate_hertz self.media_sample_rate_hertz = media_sample_rate_hertz
self.media_format = media_format self.media_format = media_format
self.media = media self.media = media
@ -82,6 +84,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"RedactionOutput": None, "RedactionOutput": None,
} }
self.identify_language = identify_language self.identify_language = identify_language
self.identify_multiple_languages = identify_multiple_languages
self.language_options = language_options self.language_options = language_options
self.identified_language_score = (None,) self.identified_language_score = (None,)
self._output_bucket_name = output_bucket_name self._output_bucket_name = output_bucket_name
@ -97,12 +100,14 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"TranscriptionJobName", "TranscriptionJobName",
"TranscriptionJobStatus", "TranscriptionJobStatus",
"LanguageCode", "LanguageCode",
"LanguageCodes",
"MediaFormat", "MediaFormat",
"Media", "Media",
"Settings", "Settings",
"StartTime", "StartTime",
"CreationTime", "CreationTime",
"IdentifyLanguage", "IdentifyLanguage",
"IdentifyMultipleLanguages",
"LanguageOptions", "LanguageOptions",
"JobExecutionSettings", "JobExecutionSettings",
], ],
@ -110,6 +115,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"TranscriptionJobName", "TranscriptionJobName",
"TranscriptionJobStatus", "TranscriptionJobStatus",
"LanguageCode", "LanguageCode",
"LanguageCodes",
"MediaSampleRateHertz", "MediaSampleRateHertz",
"MediaFormat", "MediaFormat",
"Media", "Media",
@ -119,6 +125,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"CreationTime", "CreationTime",
"CompletionTime", "CompletionTime",
"IdentifyLanguage", "IdentifyLanguage",
"IdentifyMultipleLanguages",
"LanguageOptions", "LanguageOptions",
"IdentifiedLanguageScore", "IdentifiedLanguageScore",
], ],
@ -128,9 +135,11 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
"StartTime", "StartTime",
"CompletionTime", "CompletionTime",
"LanguageCode", "LanguageCode",
"LanguageCodes",
"TranscriptionJobStatus", "TranscriptionJobStatus",
"FailureReason", "FailureReason",
"IdentifyLanguage", "IdentifyLanguage",
"IdentifyMultipleLanguages",
"IdentifiedLanguageScore", "IdentifiedLanguageScore",
"OutputLocationType", "OutputLocationType",
], ],
@ -172,12 +181,35 @@ class FakeTranscriptionJob(BaseObject, ManagedState):
) )
if self.identify_language: if self.identify_language:
self.identified_language_score = 0.999645948 self.identified_language_score = 0.999645948
# Simply identify first language passed in lanugage_options # Simply identify first language passed in language_options
# If non is set default to "en-US" # If none is set, default to "en-US"
if self.language_options is not None and len(self.language_options) > 0: if self.language_options is not None and len(self.language_options) > 0:
self.language_code = self.language_options[0] self.language_code = self.language_options[0]
else: else:
self.language_code = "en-US" self.language_code = "en-US"
if self.identify_multiple_languages:
self.identified_language_score = 0.999645948
# Identify first two languages passed in language_options
# If none is set, default to "en-US"
self.language_codes = []
if self.language_options is None or len(self.language_options) == 0:
self.language_codes.append(
{"LanguageCode": "en-US", "DurationInSeconds": 123.0}
)
else:
self.language_codes.append(
{
"LanguageCode": self.language_options[0],
"DurationInSeconds": 123.0,
}
)
if len(self.language_options) > 1:
self.language_codes.append(
{
"LanguageCode": self.language_options[1],
"DurationInSeconds": 321.0,
}
)
elif new_status == "COMPLETED": elif new_status == "COMPLETED":
self.completion_time = (datetime.now() + timedelta(seconds=10)).strftime( self.completion_time = (datetime.now() + timedelta(seconds=10)).strftime(
"%Y-%m-%d %H:%M:%S" "%Y-%m-%d %H:%M:%S"
@ -482,6 +514,7 @@ class TranscribeBackend(BaseBackend):
job_execution_settings=kwargs.get("job_execution_settings"), job_execution_settings=kwargs.get("job_execution_settings"),
content_redaction=kwargs.get("content_redaction"), content_redaction=kwargs.get("content_redaction"),
identify_language=kwargs.get("identify_language"), identify_language=kwargs.get("identify_language"),
identify_multiple_languages=kwargs.get("identify_multiple_languages"),
language_options=kwargs.get("language_options"), language_options=kwargs.get("language_options"),
) )
self.transcriptions[name] = transcription_job_object self.transcriptions[name] = transcription_job_object

View File

@ -37,6 +37,7 @@ class TranscribeResponse(BaseResponse):
job_execution_settings=self._get_param("JobExecutionSettings"), job_execution_settings=self._get_param("JobExecutionSettings"),
content_redaction=self._get_param("ContentRedaction"), content_redaction=self._get_param("ContentRedaction"),
identify_language=self._get_param("IdentifyLanguage"), identify_language=self._get_param("IdentifyLanguage"),
identify_multiple_languages=self._get_param("IdentifyMultipleLanguages"),
language_options=self._get_param("LanguageOptions"), language_options=self._get_param("LanguageOptions"),
) )
return json.dumps(response) return json.dumps(response)

View File

@ -397,11 +397,12 @@ def test_run_transcription_job_s3output_params():
@mock_transcribe @mock_transcribe
def test_run_transcription_job_identify_language_params(): def test_run_transcription_job_identify_languages_params():
region_name = "us-east-1" region_name = "us-east-1"
client = boto3.client("transcribe", region_name=region_name) client = boto3.client("transcribe", region_name=region_name)
# IdentifyLanguage
job_name = "MyJob" job_name = "MyJob"
args = { args = {
"TranscriptionJobName": job_name, "TranscriptionJobName": job_name,
@ -409,18 +410,64 @@ def test_run_transcription_job_identify_language_params():
"IdentifyLanguage": True, "IdentifyLanguage": True,
"LanguageOptions": ["en-US", "en-GB", "es-ES", "de-DE"], "LanguageOptions": ["en-US", "en-GB", "es-ES", "de-DE"],
} }
resp = client.start_transcription_job(**args) resp_data = [
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) client.start_transcription_job(**args), # CREATED
transcription_job = resp["TranscriptionJob"] client.get_transcription_job(TranscriptionJobName=job_name), # QUEUED
transcription_job.should.contain("IdentifyLanguage") client.get_transcription_job(TranscriptionJobName=job_name), # IN_PROGRESS
transcription_job.should.contain("LanguageOptions") client.list_transcription_jobs(), # IN_PROGRESS
client.get_transcription_job(TranscriptionJobName=job_name) ]
resp = client.get_transcription_job(TranscriptionJobName=job_name) for resp in resp_data:
transcription_job = resp["TranscriptionJob"] resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
transcription_job.should.contain("LanguageCode") if "TranscriptionJob" in resp:
transcription_job.should.contain("IdentifiedLanguageScore") transcription_job = resp["TranscriptionJob"]
transcription_job["LanguageCode"].should.equal("en-US") elif "TranscriptionJobSummaries" in resp:
transcription_job["IdentifiedLanguageScore"].should.equal(0.999645948) transcription_job = resp["TranscriptionJobSummaries"][0]
transcription_job.should.contain("IdentifyLanguage")
transcription_job.should_not.contain("LanguageCodes")
transcription_job.should_not.contain("IdentifyMultipleLanguages")
if "TranscriptionJobStatus" in transcription_job and (
transcription_job["TranscriptionJobStatus"] == "IN_PROGRESS"
or transcription_job["TranscriptionJobStatus"] == "COMPLETED"
):
transcription_job["LanguageCode"].should.equal("en-US")
transcription_job["IdentifiedLanguageScore"].should.equal(0.999645948)
# IdentifyMultipleLanguages
job_name = "MyJob2"
args = {
"TranscriptionJobName": job_name,
"Media": {"MediaFileUri": "s3://my-bucket/my-media-file.wav"},
"IdentifyMultipleLanguages": True,
"LanguageOptions": ["en-US", "en-GB", "es-ES", "de-DE"],
}
resp_data = [
client.start_transcription_job(**args), # CREATED
client.get_transcription_job(TranscriptionJobName=job_name), # QUEUED
client.get_transcription_job(TranscriptionJobName=job_name), # IN_PROGRESS
client.list_transcription_jobs(), # IN_PROGRESS
]
for resp in resp_data:
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
if "TranscriptionJob" in resp:
transcription_job = resp["TranscriptionJob"]
elif "TranscriptionJobSummaries" in resp:
transcription_job = resp["TranscriptionJobSummaries"][1]
transcription_job.should.contain("IdentifyMultipleLanguages")
transcription_job.should_not.contain("LanguageCode")
transcription_job.should_not.contain("IdentifyLanguage")
if "TranscriptionJobStatus" in transcription_job and (
transcription_job["TranscriptionJobStatus"] == "IN_PROGRESS"
or transcription_job["TranscriptionJobStatus"] == "COMPLETED"
):
transcription_job["LanguageCodes"][0]["LanguageCode"].should.equal("en-US")
transcription_job["LanguageCodes"][0]["DurationInSeconds"].should.equal(
123.0
)
transcription_job["LanguageCodes"][1]["LanguageCode"].should.equal("en-GB")
transcription_job["LanguageCodes"][1]["DurationInSeconds"].should.equal(
321.0
)
transcription_job["IdentifiedLanguageScore"].should.equal(0.999645948)
@mock_transcribe @mock_transcribe