Add SageMaker Feature Group (#7227)

This commit is contained in:
Bogdan Girman 2024-01-21 19:03:29 +01:00 committed by GitHub
parent 792956e959
commit 8d91d09b42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 198 additions and 0 deletions

View File

@ -980,6 +980,60 @@ class ModelPackageGroup(BaseObject):
return {k: v for k, v in response_object.items() if k in response_values} return {k: v for k, v in response_object.items() if k in response_values}
class FeatureGroup(BaseObject):
    """In-memory model of a SageMaker Feature Group.

    On construction the ARN is derived via ``arn_formatter`` and a generated
    ``DataCatalogConfig`` entry is attached to the offline store
    configuration.
    """

    def __init__(
        self,
        region_name: str,
        account_id: str,
        feature_group_name: str,
        record_identifier_feature_name: str,
        event_time_feature_name: str,
        feature_definitions: List[Dict[str, str]],
        offline_store_config: Dict[str, Any],
        role_arn: str,
        tags: Optional[List[Dict[str, str]]] = None,
    ) -> None:
        """Store the feature group attributes and derive the generated ones.

        Args:
            region_name: Region used when formatting the ARN.
            account_id: Account id used when formatting the ARN.
            feature_group_name: Name of the group; lower-cased for the ARN
                and with ``-`` replaced by ``_`` for the table name.
            record_identifier_feature_name: Stored and echoed by ``describe``.
            event_time_feature_name: Stored and echoed by ``describe``.
            feature_definitions: Stored and echoed by ``describe``.
            offline_store_config: Offline store settings. The mapping is
                copied (shallowly); the caller's dict is left untouched.
            role_arn: Stored and echoed by ``describe``.
            tags: Stored on the instance; not part of ``describe`` output.
        """
        # Take one timestamp so the table-name suffix and the creation time
        # are derived from the same instant.
        now = datetime.now()

        self.feature_group_name = feature_group_name
        self.record_identifier_feature_name = record_identifier_feature_name
        self.event_time_feature_name = event_time_feature_name
        self.feature_definitions = feature_definitions

        table_name = f"{feature_group_name.replace('-', '_')}_{int(now.timestamp())}"
        # Build a new dict instead of writing into the caller's argument.
        # Any "DataCatalogConfig" supplied by the caller is still replaced by
        # the generated one, exactly as before.
        self.offline_store_config = {
            **offline_store_config,
            "DataCatalogConfig": {
                "TableName": table_name,
                "Catalog": "AwsDataCatalog",
                "Database": "sagemaker_featurestore",
            },
        }

        self.role_arn = role_arn
        self.creation_time = now.strftime("%Y-%m-%d %H:%M:%S")
        self.feature_group_arn = arn_formatter(
            region_name=region_name,
            account_id=account_id,
            _type="feature-group",
            _id=self.feature_group_name.lower(),
        )
        self.tags = tags

    def describe(self) -> Dict[str, Any]:
        """Return the dictionary used as the DescribeFeatureGroup payload.

        ``ThroughputConfig`` and ``FeatureGroupStatus`` are fixed values;
        ``tags`` are not included.
        """
        return {
            "FeatureGroupArn": self.feature_group_arn,
            "FeatureGroupName": self.feature_group_name,
            "RecordIdentifierFeatureName": self.record_identifier_feature_name,
            "EventTimeFeatureName": self.event_time_feature_name,
            "FeatureDefinitions": self.feature_definitions,
            "CreationTime": self.creation_time,
            "OfflineStoreConfig": self.offline_store_config,
            "RoleArn": self.role_arn,
            "ThroughputConfig": {"ThroughputMode": "OnDemand"},
            "FeatureGroupStatus": "Created",
        }
class ModelPackage(BaseObject): class ModelPackage(BaseObject):
def __init__( def __init__(
self, self,
@ -1768,6 +1822,7 @@ class SageMakerModelBackend(BaseBackend):
self.model_package_groups: Dict[str, ModelPackageGroup] = {} self.model_package_groups: Dict[str, ModelPackageGroup] = {}
self.model_packages: Dict[str, ModelPackage] = {} self.model_packages: Dict[str, ModelPackage] = {}
self.model_package_name_mapping: Dict[str, str] = {} self.model_package_name_mapping: Dict[str, str] = {}
self.feature_groups: Dict[str, FeatureGroup] = {}
@staticmethod @staticmethod
def default_vpc_endpoint_service( def default_vpc_endpoint_service(
@ -3464,6 +3519,44 @@ class SageMakerModelBackend(BaseBackend):
self.model_packages[model_package.model_package_arn] = model_package self.model_packages[model_package.model_package_arn] = model_package
return model_package.model_package_arn return model_package.model_package_arn
def create_feature_group(
    self,
    feature_group_name: str,
    record_identifier_feature_name: str,
    event_time_feature_name: str,
    feature_definitions: List[Dict[str, str]],
    offline_store_config: Dict[str, Any],
    role_arn: str,
    tags: Any,
) -> str:
    """Create a ``FeatureGroup``, register it by ARN and return that ARN.

    The group is built with this backend's ``region_name`` and
    ``account_id``. An entry already stored under the same ARN is replaced.
    """
    new_group = FeatureGroup(
        feature_group_name=feature_group_name,
        record_identifier_feature_name=record_identifier_feature_name,
        event_time_feature_name=event_time_feature_name,
        feature_definitions=feature_definitions,
        offline_store_config=offline_store_config,
        role_arn=role_arn,
        region_name=self.region_name,
        account_id=self.account_id,
        tags=tags,
    )
    arn = new_group.feature_group_arn
    self.feature_groups[arn] = new_group
    return arn
def describe_feature_group(
    self,
    feature_group_name: str,
) -> Dict[str, Any]:
    """Return the ``describe()`` payload of the named feature group.

    Groups are stored by ARN, so the ARN is rebuilt from this backend's
    region and account plus the lower-cased name.

    Raises:
        KeyError: if no group is stored under the resulting ARN.
    """
    lookup_arn = arn_formatter(
        region_name=self.region_name,
        account_id=self.account_id,
        _type="feature-group",
        _id=feature_group_name.lower(),
    )
    # NOTE(review): an unknown name surfaces as a bare KeyError here;
    # consider mapping it to the service's not-found error.
    return self.feature_groups[lookup_arn].describe()
class FakeExperiment(BaseObject): class FakeExperiment(BaseObject):
def __init__( def __init__(

View File

@ -959,3 +959,23 @@ class SageMakerResponse(BaseResponse):
tags=tags, tags=tags,
) )
return json.dumps(dict(ModelPackageGroupArn=model_package_group_arn)) return json.dumps(dict(ModelPackageGroupArn=model_package_group_arn))
def create_feature_group(self) -> str:
    """Handle CreateFeatureGroup: forward request parameters to the backend.

    Returns:
        A JSON document of the form ``{"FeatureGroupArn": <arn>}``.
    """
    backend_create = self.sagemaker_backend.create_feature_group
    # Backend keyword argument -> request parameter name, in call order.
    request_fields = {
        "feature_group_name": "FeatureGroupName",
        "record_identifier_feature_name": "RecordIdentifierFeatureName",
        "event_time_feature_name": "EventTimeFeatureName",
        "feature_definitions": "FeatureDefinitions",
        "offline_store_config": "OfflineStoreConfig",
        "role_arn": "RoleArn",
        "tags": "Tags",
    }
    backend_kwargs = {
        keyword: self._get_param(param_name)
        for keyword, param_name in request_fields.items()
    }
    arn = backend_create(**backend_kwargs)
    return json.dumps({"FeatureGroupArn": arn})
def describe_feature_group(self) -> str:
    """Handle DescribeFeatureGroup: return the backend's description as JSON.

    The group is selected by the ``FeatureGroupName`` request parameter.
    """
    backend_describe = self.sagemaker_backend.describe_feature_group
    description = backend_describe(
        feature_group_name=self._get_param("FeatureGroupName")
    )
    return json.dumps(description)

View File

@ -0,0 +1,85 @@
"""Unit tests for sagemaker-supported APIs."""
import re
from datetime import datetime
import boto3
from moto import mock_sagemaker
# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
@mock_sagemaker
def test_create_feature_group():
    """create_feature_group returns the ARN expected for the given name."""
    sagemaker = boto3.client("sagemaker", region_name="us-east-2")
    definitions = [
        {"FeatureName": "some_feature", "FeatureType": "String"},
        {"FeatureName": "EventTime", "FeatureType": "Fractional"},
        {"FeatureName": "some_record_identifier", "FeatureType": "String"},
    ]
    offline_store = {
        "DisableGlueTableCreation": False,
        "S3StorageConfig": {"S3Uri": "s3://mybucket"},
    }

    response = sagemaker.create_feature_group(
        FeatureGroupName="some-feature-group-name",
        RecordIdentifierFeatureName="some_record_identifier",
        EventTimeFeatureName="EventTime",
        FeatureDefinitions=definitions,
        RoleArn="arn:aws:iam::123456789012:role/AWSFeatureStoreAccess",
        OfflineStoreConfig=offline_store,
    )

    expected_arn = "arn:aws:sagemaker:us-east-2:123456789012:feature-group/some-feature-group-name"
    assert response["FeatureGroupArn"] == expected_arn
@mock_sagemaker
def test_describe_feature_group():
    """describe_feature_group echoes the created values plus generated fields."""
    sagemaker = boto3.client("sagemaker", region_name="us-east-2")

    # Request values, reused below as the expected values.
    feature_group_name = "some-feature-group-name"
    record_identifier_feature_name = "some_record_identifier"
    event_time_feature_name = "EventTime"
    role_arn = "arn:aws:iam::123456789012:role/AWSFeatureStoreAccess"
    feature_definitions = [
        {"FeatureName": "some_feature", "FeatureType": "String"},
        {"FeatureName": event_time_feature_name, "FeatureType": "Fractional"},
        {"FeatureName": record_identifier_feature_name, "FeatureType": "String"},
    ]
    offline_store = {
        "DisableGlueTableCreation": False,
        "S3StorageConfig": {"S3Uri": "s3://mybucket"},
    }

    sagemaker.create_feature_group(
        FeatureGroupName=feature_group_name,
        RecordIdentifierFeatureName=record_identifier_feature_name,
        EventTimeFeatureName=event_time_feature_name,
        FeatureDefinitions=feature_definitions,
        RoleArn=role_arn,
        OfflineStoreConfig=offline_store,
    )
    described = sagemaker.describe_feature_group(FeatureGroupName=feature_group_name)

    # Values echoed back from the create request.
    expected_arn = "arn:aws:sagemaker:us-east-2:123456789012:feature-group/some-feature-group-name"
    assert described["FeatureGroupName"] == feature_group_name
    assert described["FeatureGroupArn"] == expected_arn
    assert described["RecordIdentifierFeatureName"] == record_identifier_feature_name
    assert described["EventTimeFeatureName"] == event_time_feature_name
    assert described["FeatureDefinitions"] == feature_definitions
    assert described["RoleArn"] == role_arn

    # Generated data catalog entry: the table name is the group name with
    # dashes turned into underscores, followed by a numeric suffix.
    catalog = described["OfflineStoreConfig"]["DataCatalogConfig"]
    assert re.match(r"^some_feature_group_name_[0-9]+$", catalog["TableName"])
    assert catalog["Catalog"] == "AwsDataCatalog"
    assert catalog["Database"] == "sagemaker_featurestore"

    assert (
        described["OfflineStoreConfig"]["S3StorageConfig"]["S3Uri"] == "s3://mybucket"
    )
    assert isinstance(described["CreationTime"], datetime)
    assert described["FeatureGroupStatus"] == "Created"