239 lines
7.7 KiB
Python
239 lines
7.7 KiB
Python
"""ComprehendBackend class with methods for supported APIs."""
|
|
|
|
from typing import Any, Dict, Iterable, List
|
|
|
|
from moto.core.base_backend import BackendDict, BaseBackend
|
|
from moto.core.common_models import BaseModel
|
|
from moto.utilities.tagging_service import TaggingService
|
|
|
|
from .exceptions import (
|
|
DetectPIIValidationException,
|
|
ResourceNotFound,
|
|
TextSizeLimitExceededException,
|
|
)
|
|
|
|
# Fixed payloads returned by the detect_* operations below. The backend does
# not analyse the submitted text, so the offsets and scores are constants and
# do not refer to positions in the caller's input.

# Entities returned by ComprehendBackend.detect_pii_entities.
CANNED_DETECT_RESPONSE = [
    {"Score": 0.9999890923500061, "Type": "NAME", "BeginOffset": 50, "EndOffset": 58},
    {"Score": 0.9999966621398926, "Type": "EMAIL", "BeginOffset": 230, "EndOffset": 259},
    {
        "Score": 0.9999954700469971,
        "Type": "BANK_ACCOUNT_NUMBER",
        "BeginOffset": 334,
        "EndOffset": 349,
    },
]

# Phrases returned by ComprehendBackend.detect_key_phrases; same scores and
# offsets as the entities above, without a "Type" key.
CANNED_PHRASES_RESPONSE = [
    {"Score": 0.9999890923500061, "BeginOffset": 50, "EndOffset": 58},
    {"Score": 0.9999966621398926, "BeginOffset": 230, "EndOffset": 259},
    {"Score": 0.9999954700469971, "BeginOffset": 334, "EndOffset": 349},
]

# Result returned by ComprehendBackend.detect_sentiment.
CANNED_SENTIMENT_RESPONSE = {
    "Sentiment": "NEUTRAL",
    "SentimentScore": {
        "Positive": 0.008101312443614006,
        "Negative": 0.0002824589901138097,
        "Neutral": 0.9916020035743713,
        "Mixed": 1.4156351426208857e-05,
    },
}
|
|
|
|
|
|
class EntityRecognizer(BaseModel):
    """In-memory record of a single entity recognizer.

    Stores the values supplied at creation time, derives the recognizer's ARN
    from them, and starts out with status ``"TRAINED"``.
    """

    def __init__(
        self,
        region_name: str,
        account_id: str,
        language_code: str,
        input_data_config: Dict[str, Any],
        data_access_role_arn: str,
        version_name: str,
        recognizer_name: str,
        volume_kms_key_id: str,
        vpc_config: Dict[str, List[str]],
        model_kms_key_id: str,
        model_policy: str,
    ):
        """Store the supplied settings and compute the ARN.

        The ARN is built from ``region_name``, ``account_id`` and
        ``recognizer_name``; a ``/version/<version_name>`` suffix is appended
        only when ``version_name`` is truthy.
        """
        # Identity
        arn = f"arn:aws:comprehend:{region_name}:{account_id}:entity-recognizer/{recognizer_name}"
        if version_name:
            arn += f"/version/{version_name}"
        self.name = recognizer_name
        self.arn = arn
        self.version_name = version_name

        # Configuration, kept exactly as passed in
        self.language_code = language_code
        self.input_data_config = input_data_config
        self.data_access_role_arn = data_access_role_arn
        self.volume_kms_key_id = volume_kms_key_id
        self.vpc_config = vpc_config
        self.model_kms_key_id = model_kms_key_id
        self.model_policy = model_policy

        # New recognizers are created already in the TRAINED state.
        self.status = "TRAINED"

    def to_dict(self) -> Dict[str, Any]:
        """Return the recognizer's properties as a dict with API-style keys.

        The recognizer name is not included as a separate key; it is only
        visible as part of ``EntityRecognizerArn``.
        """
        properties: Dict[str, Any] = {
            "EntityRecognizerArn": self.arn,
            "LanguageCode": self.language_code,
            "Status": self.status,
            "InputDataConfig": self.input_data_config,
            "DataAccessRoleArn": self.data_access_role_arn,
            "VersionName": self.version_name,
            "VolumeKmsKeyId": self.volume_kms_key_id,
            "VpcConfig": self.vpc_config,
            "ModelKmsKeyId": self.model_kms_key_id,
            "ModelPolicy": self.model_policy,
        }
        return properties
|
|
|
|
|
|
class ComprehendBackend(BaseBackend):
    """Implementation of Comprehend APIs.

    Entity recognizers are held in memory, keyed by ARN, and tags are
    delegated to a ``TaggingService``.  The ``detect_*`` operations validate
    the language code and the text size, then return fixed canned responses;
    the text itself is not analysed.
    """

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend/client/detect_key_phrases.html
    detect_key_phrases_languages = [
        "ar",
        "hi",
        "ko",
        "zh-TW",
        "ja",
        "zh",
        "de",
        "pt",
        "en",
        "it",
        "fr",
        "es",
    ]
    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend/client/detect_pii_entities.html
    detect_pii_entities_languages = ["en"]

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Recognizers keyed by their ARN.
        self.recognizers: Dict[str, EntityRecognizer] = dict()
        self.tagger = TaggingService()

    @staticmethod
    def _validate_detect_request(
        text: str, language: str, supported_languages: List[str], max_bytes: int
    ) -> None:
        """Validate the language and text size shared by the detect_* calls.

        The language is checked first, then the size, so an unsupported
        language is reported even when the text is also too large.

        :param text: the text submitted for analysis
        :param language: the language code supplied by the caller
        :param supported_languages: language codes accepted by the operation
        :param max_bytes: largest accepted size of ``text`` in UTF-8 bytes
        :raises DetectPIIValidationException: ``language`` is not supported
        :raises TextSizeLimitExceededException: ``text`` exceeds ``max_bytes``
        """
        if language not in supported_languages:
            raise DetectPIIValidationException(language, supported_languages)
        # The service limits are expressed in bytes of UTF-8 text, so measure
        # the encoded size; len(text) would under-count non-ASCII input.
        # "surrogatepass" keeps lone surrogates from raising UnicodeEncodeError.
        text_size = len(text.encode("utf-8", errors="surrogatepass"))
        if text_size > max_bytes:
            raise TextSizeLimitExceededException(text_size)

    def list_entity_recognizers(
        self, _filter: Dict[str, Any]
    ) -> Iterable[EntityRecognizer]:
        """
        Pagination is not yet implemented.
        The following filters are not yet implemented: Status, SubmitTimeBefore, SubmitTimeAfter
        """
        if "RecognizerName" in _filter:
            return [
                entity
                for entity in self.recognizers.values()
                if entity.name == _filter["RecognizerName"]
            ]
        return self.recognizers.values()

    def create_entity_recognizer(
        self,
        recognizer_name: str,
        version_name: str,
        data_access_role_arn: str,
        tags: List[Dict[str, str]],
        input_data_config: Dict[str, Any],
        language_code: str,
        volume_kms_key_id: str,
        vpc_config: Dict[str, List[str]],
        model_kms_key_id: str,
        model_policy: str,
    ) -> str:
        """
        The ClientRequestToken-parameter is not yet implemented

        Creates the recognizer, stores it under its ARN (replacing any
        existing entry with the same ARN), tags it, and returns the ARN.
        """
        recognizer = EntityRecognizer(
            region_name=self.region_name,
            account_id=self.account_id,
            language_code=language_code,
            input_data_config=input_data_config,
            data_access_role_arn=data_access_role_arn,
            version_name=version_name,
            recognizer_name=recognizer_name,
            volume_kms_key_id=volume_kms_key_id,
            vpc_config=vpc_config,
            model_kms_key_id=model_kms_key_id,
            model_policy=model_policy,
        )
        self.recognizers[recognizer.arn] = recognizer
        self.tagger.tag_resource(recognizer.arn, tags)
        return recognizer.arn

    def describe_entity_recognizer(
        self, entity_recognizer_arn: str
    ) -> EntityRecognizer:
        """Return the recognizer stored under ``entity_recognizer_arn``.

        :raises ResourceNotFound: no recognizer is stored under that ARN
        """
        if entity_recognizer_arn not in self.recognizers:
            raise ResourceNotFound
        return self.recognizers[entity_recognizer_arn]

    def stop_training_entity_recognizer(self, entity_recognizer_arn: str) -> None:
        """Mark a recognizer in status ``TRAINING`` as ``STOP_REQUESTED``.

        Recognizers in any other status are left unchanged.

        :raises ResourceNotFound: no recognizer is stored under that ARN
        """
        recognizer = self.describe_entity_recognizer(entity_recognizer_arn)
        if recognizer.status == "TRAINING":
            recognizer.status = "STOP_REQUESTED"

    def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:
        """Return the ``Tags`` entry the tagging service reports for the ARN."""
        return self.tagger.list_tags_for_resource(resource_arn)["Tags"]

    def delete_entity_recognizer(self, entity_recognizer_arn: str) -> None:
        """Remove the recognizer if present; unknown ARNs are ignored."""
        self.recognizers.pop(entity_recognizer_arn, None)

    def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
        """Pass ``tags`` for ``resource_arn`` to the tagging service."""
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
        """Ask the tagging service to remove ``tag_keys`` from ``resource_arn``."""
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def detect_pii_entities(self, text: str, language: str) -> List[Dict[str, Any]]:
        """Validate the request and return the canned PII entities.

        :raises DetectPIIValidationException: ``language`` is not in
            ``detect_pii_entities_languages``
        :raises TextSizeLimitExceededException: ``text`` is larger than
            100000 UTF-8 bytes
        """
        self._validate_detect_request(
            text, language, self.detect_pii_entities_languages, 100000
        )
        return CANNED_DETECT_RESPONSE

    def detect_key_phrases(self, text: str, language: str) -> List[Dict[str, Any]]:
        """Validate the request and return the canned key phrases.

        :raises DetectPIIValidationException: ``language`` is not in
            ``detect_key_phrases_languages``
        :raises TextSizeLimitExceededException: ``text`` is larger than
            100000 UTF-8 bytes
        """
        self._validate_detect_request(
            text, language, self.detect_key_phrases_languages, 100000
        )
        return CANNED_PHRASES_RESPONSE

    def detect_sentiment(self, text: str, language: str) -> Dict[str, Any]:
        """Validate the request and return the canned sentiment result.

        The accepted languages are those of ``detect_key_phrases_languages``.

        :raises DetectPIIValidationException: ``language`` is not in
            ``detect_key_phrases_languages``
        :raises TextSizeLimitExceededException: ``text`` is larger than
            5000 UTF-8 bytes
        """
        self._validate_detect_request(
            text, language, self.detect_key_phrases_languages, 5000
        )
        return CANNED_SENTIMENT_RESPONSE
|
|
|
|
|
|
# Module-level BackendDict built from ComprehendBackend for the "comprehend" service name.
# NOTE(review): presumably hands out one backend per account/region, matching the
# (region_name, account_id) arguments of ComprehendBackend.__init__ — confirm against BackendDict.
comprehend_backends = BackendDict(ComprehendBackend, "comprehend")
|