moto/moto/comprehend/models.py
2024-01-27 19:38:09 +00:00

239 lines
7.7 KiB
Python

"""ComprehendBackend class with methods for supported APIs."""
from typing import Any, Dict, Iterable, List
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.utilities.tagging_service import TaggingService
from .exceptions import (
DetectPIIValidationException,
ResourceNotFound,
TextSizeLimitExceededException,
)
# Fixed (score, entity type, begin offset, end offset) rows shared by the two
# canned list responses below.
_CANNED_SPANS = (
    (0.9999890923500061, "NAME", 50, 58),
    (0.9999966621398926, "EMAIL", 230, 259),
    (0.9999954700469971, "BANK_ACCOUNT_NUMBER", 334, 349),
)

# Returned as-is by ComprehendBackend.detect_pii_entities. The same list object
# is handed out on every call, so callers should treat it as read-only.
CANNED_DETECT_RESPONSE = [
    {
        "Score": score,
        "Type": entity_type,
        "BeginOffset": begin_offset,
        "EndOffset": end_offset,
    }
    for score, entity_type, begin_offset, end_offset in _CANNED_SPANS
]

# Returned as-is by ComprehendBackend.detect_key_phrases; same spans as above
# but without the "Type" field.
CANNED_PHRASES_RESPONSE = [
    {
        "Score": score,
        "BeginOffset": begin_offset,
        "EndOffset": end_offset,
    }
    for score, _entity_type, begin_offset, end_offset in _CANNED_SPANS
]
# Per-label confidence scores embedded in the canned sentiment response.
_CANNED_SENTIMENT_SCORES = {
    "Positive": 0.008101312443614006,
    "Negative": 0.0002824589901138097,
    "Neutral": 0.9916020035743713,
    "Mixed": 1.4156351426208857e-05,
}

# Returned as-is by ComprehendBackend.detect_sentiment. The same dict object is
# handed out on every call, so callers should treat it as read-only.
CANNED_SENTIMENT_RESPONSE = {
    "Sentiment": "NEUTRAL",
    "SentimentScore": _CANNED_SENTIMENT_SCORES,
}
class EntityRecognizer(BaseModel):
    """In-memory record of a single entity recognizer.

    The ARN is derived from the region, account and recognizer name, with a
    ``/version/<version_name>`` suffix appended when a version name is given.
    New recognizers start out with status ``TRAINED``.
    """

    def __init__(
        self,
        region_name: str,
        account_id: str,
        language_code: str,
        input_data_config: Dict[str, Any],
        data_access_role_arn: str,
        version_name: str,
        recognizer_name: str,
        volume_kms_key_id: str,
        vpc_config: Dict[str, List[str]],
        model_kms_key_id: str,
        model_policy: str,
    ):
        base_arn = f"arn:aws:comprehend:{region_name}:{account_id}:entity-recognizer/{recognizer_name}"

        self.name = recognizer_name
        # Only versioned recognizers carry the version suffix in their ARN.
        self.arn = f"{base_arn}/version/{version_name}" if version_name else base_arn
        self.language_code = language_code
        self.input_data_config = input_data_config
        self.data_access_role_arn = data_access_role_arn
        self.version_name = version_name
        self.volume_kms_key_id = volume_kms_key_id
        self.vpc_config = vpc_config
        self.model_kms_key_id = model_kms_key_id
        self.model_policy = model_policy
        self.status = "TRAINED"

    def to_dict(self) -> Dict[str, Any]:
        """Return the recognizer's properties as a dict with API-style keys."""
        return dict(
            EntityRecognizerArn=self.arn,
            LanguageCode=self.language_code,
            Status=self.status,
            InputDataConfig=self.input_data_config,
            DataAccessRoleArn=self.data_access_role_arn,
            VersionName=self.version_name,
            VolumeKmsKeyId=self.volume_kms_key_id,
            VpcConfig=self.vpc_config,
            ModelKmsKeyId=self.model_kms_key_id,
            ModelPolicy=self.model_policy,
        )
class ComprehendBackend(BaseBackend):
    """Implementation of Comprehend APIs.

    Entity recognizers are kept in memory in ``self.recognizers``, keyed by
    ARN, and tagging is delegated to a ``TaggingService``. The ``detect_*``
    methods validate the language code and the text size and then return
    fixed, canned responses.
    """

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend/client/detect_key_phrases.html
    detect_key_phrases_languages = [
        "ar",
        "hi",
        "ko",
        "zh-TW",
        "ja",
        "zh",
        "de",
        "pt",
        "en",
        "it",
        "fr",
        "es",
    ]
    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend/client/detect_pii_entities.html
    detect_pii_entities_languages = ["en"]

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Recognizers keyed by their ARN.
        self.recognizers: Dict[str, EntityRecognizer] = dict()
        self.tagger = TaggingService()

    @staticmethod
    def _validate_text_size(text: str, max_bytes: int) -> None:
        """Raise TextSizeLimitExceededException when *text* exceeds *max_bytes*.

        Comprehend documents its text limits in bytes of UTF-8 encoded text,
        so the encoded length is measured rather than the number of characters
        (the two differ for any non-ASCII input).
        """
        # "surrogatepass" keeps a stray surrogate in the input from raising
        # UnicodeEncodeError here; it is simply counted like any other char.
        text_size = len(text.encode("utf-8", "surrogatepass"))
        if text_size > max_bytes:
            raise TextSizeLimitExceededException(text_size)

    def list_entity_recognizers(
        self, _filter: Dict[str, Any]
    ) -> Iterable[EntityRecognizer]:
        """
        Pagination is not yet implemented.
        The following filters are not yet implemented: Status, SubmitTimeBefore, SubmitTimeAfter
        """
        recognizers = list(self.recognizers.values())
        # An empty (or missing) filter means "return everything".
        if _filter and "RecognizerName" in _filter:
            wanted_name = _filter["RecognizerName"]
            return [entity for entity in recognizers if entity.name == wanted_name]
        # Always hand back a list snapshot rather than a live dict view, so the
        # result is the same kind of object whether or not a filter was given.
        return recognizers

    def create_entity_recognizer(
        self,
        recognizer_name: str,
        version_name: str,
        data_access_role_arn: str,
        tags: List[Dict[str, str]],
        input_data_config: Dict[str, Any],
        language_code: str,
        volume_kms_key_id: str,
        vpc_config: Dict[str, List[str]],
        model_kms_key_id: str,
        model_policy: str,
    ) -> str:
        """
        The ClientRequestToken-parameter is not yet implemented
        """
        recognizer = EntityRecognizer(
            region_name=self.region_name,
            account_id=self.account_id,
            language_code=language_code,
            input_data_config=input_data_config,
            data_access_role_arn=data_access_role_arn,
            version_name=version_name,
            recognizer_name=recognizer_name,
            volume_kms_key_id=volume_kms_key_id,
            vpc_config=vpc_config,
            model_kms_key_id=model_kms_key_id,
            model_policy=model_policy,
        )
        # A recognizer with the same ARN (same name/version) is overwritten.
        self.recognizers[recognizer.arn] = recognizer
        self.tagger.tag_resource(recognizer.arn, tags)
        return recognizer.arn

    def describe_entity_recognizer(
        self, entity_recognizer_arn: str
    ) -> EntityRecognizer:
        """Return the recognizer stored under the ARN, or raise ResourceNotFound."""
        recognizer = self.recognizers.get(entity_recognizer_arn)
        if recognizer is None:
            raise ResourceNotFound
        return recognizer

    def stop_training_entity_recognizer(self, entity_recognizer_arn: str) -> None:
        """Mark a recognizer that is still TRAINING as STOP_REQUESTED.

        Raises ResourceNotFound for an unknown ARN; recognizers in any other
        status are left unchanged.
        """
        recognizer = self.describe_entity_recognizer(entity_recognizer_arn)
        if recognizer.status == "TRAINING":
            recognizer.status = "STOP_REQUESTED"

    def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:
        """Return the tags recorded for the given ARN."""
        return self.tagger.list_tags_for_resource(resource_arn)["Tags"]

    def delete_entity_recognizer(self, entity_recognizer_arn: str) -> None:
        """Remove the recognizer if it exists; unknown ARNs are silently ignored."""
        # NOTE(review): tags recorded in self.tagger for this ARN are not
        # removed here - confirm whether they should be cleaned up as well.
        self.recognizers.pop(entity_recognizer_arn, None)

    def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
        """Attach the given tags to the ARN."""
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
        """Remove the tags with the given keys from the ARN."""
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def detect_pii_entities(self, text: str, language: str) -> List[Dict[str, Any]]:
        """Validate the request and return the canned PII entity list.

        Raises DetectPIIValidationException for an unsupported language and
        TextSizeLimitExceededException for text larger than 100000 bytes.
        """
        if language not in self.detect_pii_entities_languages:
            raise DetectPIIValidationException(
                language, self.detect_pii_entities_languages
            )
        self._validate_text_size(text, 100000)
        return CANNED_DETECT_RESPONSE

    def detect_key_phrases(self, text: str, language: str) -> List[Dict[str, Any]]:
        """Validate the request and return the canned key-phrase list.

        Raises DetectPIIValidationException for an unsupported language and
        TextSizeLimitExceededException for text larger than 100000 bytes.
        """
        if language not in self.detect_key_phrases_languages:
            raise DetectPIIValidationException(
                language, self.detect_key_phrases_languages
            )
        self._validate_text_size(text, 100000)
        return CANNED_PHRASES_RESPONSE

    def detect_sentiment(self, text: str, language: str) -> Dict[str, Any]:
        """Validate the request and return the canned sentiment result.

        Uses the same language list as detect_key_phrases. Raises
        DetectPIIValidationException for an unsupported language and
        TextSizeLimitExceededException for text larger than 5000 bytes.
        """
        if language not in self.detect_key_phrases_languages:
            raise DetectPIIValidationException(
                language, self.detect_key_phrases_languages
            )
        self._validate_text_size(text, 5000)
        return CANNED_SENTIMENT_RESPONSE
# Module-level BackendDict built from ComprehendBackend and the string "comprehend".
# NOTE(review): presumably this hands out one backend per account/region for the
# "comprehend" service - confirm against moto.core.base_backend.BackendDict.
comprehend_backends = BackendDict(ComprehendBackend, "comprehend")