From 8d91d09b42a3ccb3f622d9d8e9e79976408f6d92 Mon Sep 17 00:00:00 2001 From: Bogdan Girman Date: Sun, 21 Jan 2024 19:03:29 +0100 Subject: [PATCH] Add SageMaker Feature Group (#7227) --- moto/sagemaker/models.py | 93 +++++++++++++++++++ moto/sagemaker/responses.py | 20 ++++ .../test_sagemaker_feature_groups.py | 85 +++++++++++++++++ 3 files changed, 198 insertions(+) create mode 100644 tests/test_sagemaker/test_sagemaker_feature_groups.py diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 06f5896cc..c284c65ec 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -980,6 +980,60 @@ class ModelPackageGroup(BaseObject): return {k: v for k, v in response_object.items() if k in response_values} +class FeatureGroup(BaseObject): + def __init__( + self, + region_name: str, + account_id: str, + feature_group_name: str, + record_identifier_feature_name: str, + event_time_feature_name: str, + feature_definitions: List[Dict[str, str]], + offline_store_config: Dict[str, Any], + role_arn: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + self.feature_group_name = feature_group_name + self.record_identifier_feature_name = record_identifier_feature_name + self.event_time_feature_name = event_time_feature_name + self.feature_definitions = feature_definitions + + table_name = ( + f"{feature_group_name.replace('-','_')}_{int(datetime.now().timestamp())}" + ) + offline_store_config["DataCatalogConfig"] = { + "TableName": table_name, + "Catalog": "AwsDataCatalog", + "Database": "sagemaker_featurestore", + } + + self.offline_store_config = offline_store_config + self.role_arn = role_arn + + self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.feature_group_arn = arn_formatter( + region_name=region_name, + account_id=account_id, + _type="feature-group", + _id=f"{self.feature_group_name.lower()}", + ) + self.tags = tags + + def describe(self) -> Dict[str, Any]: + return { + "FeatureGroupArn": self.feature_group_arn, + "FeatureGroupName": self.feature_group_name, + "RecordIdentifierFeatureName": self.record_identifier_feature_name, + "EventTimeFeatureName": self.event_time_feature_name, + "FeatureDefinitions": self.feature_definitions, + "CreationTime": self.creation_time, + "OfflineStoreConfig": self.offline_store_config, + "RoleArn": self.role_arn, + "ThroughputConfig": {"ThroughputMode": "OnDemand"}, + "FeatureGroupStatus": "Created", + } + + class ModelPackage(BaseObject): def __init__( self, @@ -1768,6 +1822,7 @@ class SageMakerModelBackend(BaseBackend): self.model_package_groups: Dict[str, ModelPackageGroup] = {} self.model_packages: Dict[str, ModelPackage] = {} self.model_package_name_mapping: Dict[str, str] = {} + self.feature_groups: Dict[str, FeatureGroup] = {} @staticmethod def default_vpc_endpoint_service( @@ -3464,6 +3519,44 @@ class SageMakerModelBackend(BaseBackend): self.model_packages[model_package.model_package_arn] = model_package return model_package.model_package_arn + def create_feature_group( + self, + feature_group_name: str, + record_identifier_feature_name: str, + event_time_feature_name: str, + feature_definitions: List[Dict[str, str]], + offline_store_config: Dict[str, Any], + role_arn: str, + tags: Any, + ) -> str: + feature_group = FeatureGroup( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_identifier_feature_name, + event_time_feature_name=event_time_feature_name, + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=role_arn, + region_name=self.region_name, + account_id=self.account_id, + tags=tags, + ) + self.feature_groups[feature_group.feature_group_arn] = feature_group + return feature_group.feature_group_arn + + def describe_feature_group( + self, + feature_group_name: str, + ) -> Dict[str, Any]: + feature_group_arn = arn_formatter( + region_name=self.region_name, + account_id=self.account_id, + _type="feature-group", + _id=f"{feature_group_name.lower()}", + ) + + feature_group = self.feature_groups[feature_group_arn] + return feature_group.describe() + class FakeExperiment(BaseObject): def __init__( diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index f260130d1..fa63920b0 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -959,3 +959,23 @@ class SageMakerResponse(BaseResponse): tags=tags, ) return json.dumps(dict(ModelPackageGroupArn=model_package_group_arn)) + + def create_feature_group(self) -> str: + feature_group_arn = self.sagemaker_backend.create_feature_group( + feature_group_name=self._get_param("FeatureGroupName"), + record_identifier_feature_name=self._get_param( + "RecordIdentifierFeatureName" + ), + event_time_feature_name=self._get_param("EventTimeFeatureName"), + feature_definitions=self._get_param("FeatureDefinitions"), + offline_store_config=self._get_param("OfflineStoreConfig"), + role_arn=self._get_param("RoleArn"), + tags=self._get_param("Tags"), + ) + return json.dumps(dict(FeatureGroupArn=feature_group_arn)) + + def describe_feature_group(self) -> str: + resp = self.sagemaker_backend.describe_feature_group( + feature_group_name=self._get_param("FeatureGroupName"), + ) + return json.dumps(resp) diff --git a/tests/test_sagemaker/test_sagemaker_feature_groups.py b/tests/test_sagemaker/test_sagemaker_feature_groups.py new file mode 100644 index 000000000..c3f4a1475 --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_feature_groups.py @@ -0,0 +1,85 @@ +"""Unit tests for sagemaker-supported APIs.""" +import re +from datetime import datetime + +import boto3 + +from moto import mock_sagemaker + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +@mock_sagemaker +def test_create_feature_group(): + client = boto3.client("sagemaker", region_name="us-east-2") + resp = client.create_feature_group( + FeatureGroupName="some-feature-group-name", + RecordIdentifierFeatureName="some_record_identifier", + EventTimeFeatureName="EventTime", + FeatureDefinitions=[ + {"FeatureName": "some_feature", "FeatureType": "String"}, + {"FeatureName": "EventTime", "FeatureType": "Fractional"}, + {"FeatureName": "some_record_identifier", "FeatureType": "String"}, + ], + RoleArn="arn:aws:iam::123456789012:role/AWSFeatureStoreAccess", + OfflineStoreConfig={ + "DisableGlueTableCreation": False, + "S3StorageConfig": {"S3Uri": "s3://mybucket"}, + }, + ) + + assert ( + resp["FeatureGroupArn"] + == "arn:aws:sagemaker:us-east-2:123456789012:feature-group/some-feature-group-name" + ) + + +@mock_sagemaker +def test_describe_feature_group(): + client = boto3.client("sagemaker", region_name="us-east-2") + feature_group_name = "some-feature-group-name" + record_identifier_feature_name = "some_record_identifier" + event_time_feature_name = "EventTime" + role_arn = "arn:aws:iam::123456789012:role/AWSFeatureStoreAccess" + feature_definitions = [ + {"FeatureName": "some_feature", "FeatureType": "String"}, + {"FeatureName": event_time_feature_name, "FeatureType": "Fractional"}, + {"FeatureName": record_identifier_feature_name, "FeatureType": "String"}, + ] + client.create_feature_group( + FeatureGroupName=feature_group_name, + RecordIdentifierFeatureName=record_identifier_feature_name, + EventTimeFeatureName=event_time_feature_name, + FeatureDefinitions=feature_definitions, + RoleArn=role_arn, + OfflineStoreConfig={ + "DisableGlueTableCreation": False, + "S3StorageConfig": {"S3Uri": "s3://mybucket"}, + }, + ) + resp = client.describe_feature_group(FeatureGroupName=feature_group_name) + + assert resp["FeatureGroupName"] == feature_group_name + assert ( + resp["FeatureGroupArn"] + == "arn:aws:sagemaker:us-east-2:123456789012:feature-group/some-feature-group-name" + ) + assert resp["RecordIdentifierFeatureName"] == record_identifier_feature_name + assert resp["EventTimeFeatureName"] == event_time_feature_name + assert resp["FeatureDefinitions"] == feature_definitions + assert resp["RoleArn"] == role_arn + assert re.match( + r"^some_feature_group_name_[0-9]+$", + resp["OfflineStoreConfig"]["DataCatalogConfig"]["TableName"], + ) + assert ( + resp["OfflineStoreConfig"]["DataCatalogConfig"]["Catalog"] == "AwsDataCatalog" + ) + assert ( + resp["OfflineStoreConfig"]["DataCatalogConfig"]["Database"] + == "sagemaker_featurestore" + ) + assert resp["OfflineStoreConfig"]["S3StorageConfig"]["S3Uri"] == "s3://mybucket" + assert isinstance(resp["CreationTime"], datetime) + assert resp["FeatureGroupStatus"] == "Created"