Add SageMaker Feature Group (#7227)
This commit is contained in:
parent
792956e959
commit
8d91d09b42
@ -980,6 +980,60 @@ class ModelPackageGroup(BaseObject):
|
||||
return {k: v for k, v in response_object.items() if k in response_values}
|
||||
|
||||
|
||||
class FeatureGroup(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
region_name: str,
|
||||
account_id: str,
|
||||
feature_group_name: str,
|
||||
record_identifier_feature_name: str,
|
||||
event_time_feature_name: str,
|
||||
feature_definitions: List[Dict[str, str]],
|
||||
offline_store_config: Dict[str, Any],
|
||||
role_arn: str,
|
||||
tags: Optional[List[Dict[str, str]]] = None,
|
||||
) -> None:
|
||||
self.feature_group_name = feature_group_name
|
||||
self.record_identifier_feature_name = record_identifier_feature_name
|
||||
self.event_time_feature_name = event_time_feature_name
|
||||
self.feature_definitions = feature_definitions
|
||||
|
||||
table_name = (
|
||||
f"{feature_group_name.replace('-','_')}_{int(datetime.now().timestamp())}"
|
||||
)
|
||||
offline_store_config["DataCatalogConfig"] = {
|
||||
"TableName": table_name,
|
||||
"Catalog": "AwsDataCatalog",
|
||||
"Database": "sagemaker_featurestore",
|
||||
}
|
||||
|
||||
self.offline_store_config = offline_store_config
|
||||
self.role_arn = role_arn
|
||||
|
||||
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.feature_group_arn = arn_formatter(
|
||||
region_name=region_name,
|
||||
account_id=account_id,
|
||||
_type="feature-group",
|
||||
_id=f"{self.feature_group_name.lower()}",
|
||||
)
|
||||
self.tags = tags
|
||||
|
||||
def describe(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"FeatureGroupArn": self.feature_group_arn,
|
||||
"FeatureGroupName": self.feature_group_name,
|
||||
"RecordIdentifierFeatureName": self.record_identifier_feature_name,
|
||||
"EventTimeFeatureName": self.event_time_feature_name,
|
||||
"FeatureDefinitions": self.feature_definitions,
|
||||
"CreationTime": self.creation_time,
|
||||
"OfflineStoreConfig": self.offline_store_config,
|
||||
"RoleArn": self.role_arn,
|
||||
"ThroughputConfig": {"ThroughputMode": "OnDemand"},
|
||||
"FeatureGroupStatus": "Created",
|
||||
}
|
||||
|
||||
|
||||
class ModelPackage(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
@ -1768,6 +1822,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
self.model_package_groups: Dict[str, ModelPackageGroup] = {}
|
||||
self.model_packages: Dict[str, ModelPackage] = {}
|
||||
self.model_package_name_mapping: Dict[str, str] = {}
|
||||
self.feature_groups: Dict[str, FeatureGroup] = {}
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(
|
||||
@ -3464,6 +3519,44 @@ class SageMakerModelBackend(BaseBackend):
|
||||
self.model_packages[model_package.model_package_arn] = model_package
|
||||
return model_package.model_package_arn
|
||||
|
||||
def create_feature_group(
|
||||
self,
|
||||
feature_group_name: str,
|
||||
record_identifier_feature_name: str,
|
||||
event_time_feature_name: str,
|
||||
feature_definitions: List[Dict[str, str]],
|
||||
offline_store_config: Dict[str, Any],
|
||||
role_arn: str,
|
||||
tags: Any,
|
||||
) -> str:
|
||||
feature_group = FeatureGroup(
|
||||
feature_group_name=feature_group_name,
|
||||
record_identifier_feature_name=record_identifier_feature_name,
|
||||
event_time_feature_name=event_time_feature_name,
|
||||
feature_definitions=feature_definitions,
|
||||
offline_store_config=offline_store_config,
|
||||
role_arn=role_arn,
|
||||
region_name=self.region_name,
|
||||
account_id=self.account_id,
|
||||
tags=tags,
|
||||
)
|
||||
self.feature_groups[feature_group.feature_group_arn] = feature_group
|
||||
return feature_group.feature_group_arn
|
||||
|
||||
def describe_feature_group(
|
||||
self,
|
||||
feature_group_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
feature_group_arn = arn_formatter(
|
||||
region_name=self.region_name,
|
||||
account_id=self.account_id,
|
||||
_type="feature-group",
|
||||
_id=f"{feature_group_name.lower()}",
|
||||
)
|
||||
|
||||
feature_group = self.feature_groups[feature_group_arn]
|
||||
return feature_group.describe()
|
||||
|
||||
|
||||
class FakeExperiment(BaseObject):
|
||||
def __init__(
|
||||
|
@ -959,3 +959,23 @@ class SageMakerResponse(BaseResponse):
|
||||
tags=tags,
|
||||
)
|
||||
return json.dumps(dict(ModelPackageGroupArn=model_package_group_arn))
|
||||
|
||||
def create_feature_group(self) -> str:
|
||||
feature_group_arn = self.sagemaker_backend.create_feature_group(
|
||||
feature_group_name=self._get_param("FeatureGroupName"),
|
||||
record_identifier_feature_name=self._get_param(
|
||||
"RecordIdentifierFeatureName"
|
||||
),
|
||||
event_time_feature_name=self._get_param("EventTimeFeatureName"),
|
||||
feature_definitions=self._get_param("FeatureDefinitions"),
|
||||
offline_store_config=self._get_param("OfflineStoreConfig"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
tags=self._get_param("Tags"),
|
||||
)
|
||||
return json.dumps(dict(FeatureGroupArn=feature_group_arn))
|
||||
|
||||
def describe_feature_group(self) -> str:
|
||||
resp = self.sagemaker_backend.describe_feature_group(
|
||||
feature_group_name=self._get_param("FeatureGroupName"),
|
||||
)
|
||||
return json.dumps(resp)
|
||||
|
85
tests/test_sagemaker/test_sagemaker_feature_groups.py
Normal file
85
tests/test_sagemaker/test_sagemaker_feature_groups.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""Unit tests for sagemaker-supported APIs."""
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
import boto3
|
||||
|
||||
from moto import mock_sagemaker
|
||||
|
||||
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_feature_group():
|
||||
client = boto3.client("sagemaker", region_name="us-east-2")
|
||||
resp = client.create_feature_group(
|
||||
FeatureGroupName="some-feature-group-name",
|
||||
RecordIdentifierFeatureName="some_record_identifier",
|
||||
EventTimeFeatureName="EventTime",
|
||||
FeatureDefinitions=[
|
||||
{"FeatureName": "some_feature", "FeatureType": "String"},
|
||||
{"FeatureName": "EventTime", "FeatureType": "Fractional"},
|
||||
{"FeatureName": "some_record_identifier", "FeatureType": "String"},
|
||||
],
|
||||
RoleArn="arn:aws:iam::123456789012:role/AWSFeatureStoreAccess",
|
||||
OfflineStoreConfig={
|
||||
"DisableGlueTableCreation": False,
|
||||
"S3StorageConfig": {"S3Uri": "s3://mybucket"},
|
||||
},
|
||||
)
|
||||
|
||||
assert (
|
||||
resp["FeatureGroupArn"]
|
||||
== "arn:aws:sagemaker:us-east-2:123456789012:feature-group/some-feature-group-name"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_feature_group():
|
||||
client = boto3.client("sagemaker", region_name="us-east-2")
|
||||
feature_group_name = "some-feature-group-name"
|
||||
record_identifier_feature_name = "some_record_identifier"
|
||||
event_time_feature_name = "EventTime"
|
||||
role_arn = "arn:aws:iam::123456789012:role/AWSFeatureStoreAccess"
|
||||
feature_definitions = [
|
||||
{"FeatureName": "some_feature", "FeatureType": "String"},
|
||||
{"FeatureName": event_time_feature_name, "FeatureType": "Fractional"},
|
||||
{"FeatureName": record_identifier_feature_name, "FeatureType": "String"},
|
||||
]
|
||||
client.create_feature_group(
|
||||
FeatureGroupName=feature_group_name,
|
||||
RecordIdentifierFeatureName=record_identifier_feature_name,
|
||||
EventTimeFeatureName=event_time_feature_name,
|
||||
FeatureDefinitions=feature_definitions,
|
||||
RoleArn=role_arn,
|
||||
OfflineStoreConfig={
|
||||
"DisableGlueTableCreation": False,
|
||||
"S3StorageConfig": {"S3Uri": "s3://mybucket"},
|
||||
},
|
||||
)
|
||||
resp = client.describe_feature_group(FeatureGroupName=feature_group_name)
|
||||
|
||||
assert resp["FeatureGroupName"] == feature_group_name
|
||||
assert (
|
||||
resp["FeatureGroupArn"]
|
||||
== "arn:aws:sagemaker:us-east-2:123456789012:feature-group/some-feature-group-name"
|
||||
)
|
||||
assert resp["RecordIdentifierFeatureName"] == record_identifier_feature_name
|
||||
assert resp["EventTimeFeatureName"] == event_time_feature_name
|
||||
assert resp["FeatureDefinitions"] == feature_definitions
|
||||
assert resp["RoleArn"] == role_arn
|
||||
assert re.match(
|
||||
r"^some_feature_group_name_[0-9]+$",
|
||||
resp["OfflineStoreConfig"]["DataCatalogConfig"]["TableName"],
|
||||
)
|
||||
assert (
|
||||
resp["OfflineStoreConfig"]["DataCatalogConfig"]["Catalog"] == "AwsDataCatalog"
|
||||
)
|
||||
assert (
|
||||
resp["OfflineStoreConfig"]["DataCatalogConfig"]["Database"]
|
||||
== "sagemaker_featurestore"
|
||||
)
|
||||
assert resp["OfflineStoreConfig"]["S3StorageConfig"]["S3Uri"] == "s3://mybucket"
|
||||
assert isinstance(resp["CreationTime"], datetime)
|
||||
assert resp["FeatureGroupStatus"] == "Created"
|
Loading…
Reference in New Issue
Block a user