Skip to content

Commit

Permalink
Add SageMaker Feature Group (#7227)
Browse files Browse the repository at this point in the history
  • Loading branch information
bogdangi authored Jan 21, 2024
1 parent 792956e commit 8d91d09
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 0 deletions.
93 changes: 93 additions & 0 deletions moto/sagemaker/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,60 @@ def gen_response_object(self) -> Dict[str, Any]:
return {k: v for k, v in response_object.items() if k in response_values}


class FeatureGroup(BaseObject):
    """In-memory model of a SageMaker Feature Group."""

    def __init__(
        self,
        region_name: str,
        account_id: str,
        feature_group_name: str,
        record_identifier_feature_name: str,
        event_time_feature_name: str,
        feature_definitions: List[Dict[str, str]],
        offline_store_config: Optional[Dict[str, Any]],
        role_arn: str,
        tags: Optional[List[Dict[str, str]]] = None,
    ) -> None:
        """Record the feature group's settings and derive its ARN.

        Args:
            region_name: Region used when building the ARN.
            account_id: Account used when building the ARN.
            feature_group_name: Name of the group; lower-cased for the ARN.
            record_identifier_feature_name: Stored and echoed by ``describe``.
            event_time_feature_name: Stored and echoed by ``describe``.
            feature_definitions: Stored and echoed by ``describe``.
            offline_store_config: Offline store settings.  A copy is stored
                with a generated ``DataCatalogConfig`` entry added; the
                caller's dict is left untouched.  When ``None``, no offline
                store configuration is recorded.
            role_arn: Stored and echoed by ``describe``.
            tags: Stored on the instance; not part of ``describe`` output.
        """
        self.feature_group_name = feature_group_name
        self.record_identifier_feature_name = record_identifier_feature_name
        self.event_time_feature_name = event_time_feature_name
        self.feature_definitions = feature_definitions

        # Read the clock once so the table-name suffix and the creation time
        # always refer to the same instant.
        now = datetime.now()

        if offline_store_config is not None:
            table_prefix = feature_group_name.replace("-", "_")
            # Build a new dict instead of writing into the argument, so the
            # caller's request data is not modified as a side effect.
            offline_store_config = {
                **offline_store_config,
                "DataCatalogConfig": {
                    "TableName": f"{table_prefix}_{int(now.timestamp())}",
                    "Catalog": "AwsDataCatalog",
                    "Database": "sagemaker_featurestore",
                },
            }

        self.offline_store_config = offline_store_config
        self.role_arn = role_arn

        self.creation_time = now.strftime("%Y-%m-%d %H:%M:%S")
        self.feature_group_arn = arn_formatter(
            region_name=region_name,
            account_id=account_id,
            _type="feature-group",
            _id=self.feature_group_name.lower(),
        )
        self.tags = tags

    def describe(self) -> Dict[str, Any]:
        """Return the stored attributes as a dict keyed by response field name.

        ``ThroughputConfig`` and ``FeatureGroupStatus`` are fixed values.
        """
        return {
            "FeatureGroupArn": self.feature_group_arn,
            "FeatureGroupName": self.feature_group_name,
            "RecordIdentifierFeatureName": self.record_identifier_feature_name,
            "EventTimeFeatureName": self.event_time_feature_name,
            "FeatureDefinitions": self.feature_definitions,
            "CreationTime": self.creation_time,
            "OfflineStoreConfig": self.offline_store_config,
            "RoleArn": self.role_arn,
            "ThroughputConfig": {"ThroughputMode": "OnDemand"},
            "FeatureGroupStatus": "Created",
        }


class ModelPackage(BaseObject):
def __init__(
self,
Expand Down Expand Up @@ -1768,6 +1822,7 @@ def __init__(self, region_name: str, account_id: str):
self.model_package_groups: Dict[str, ModelPackageGroup] = {}
self.model_packages: Dict[str, ModelPackage] = {}
self.model_package_name_mapping: Dict[str, str] = {}
self.feature_groups: Dict[str, FeatureGroup] = {}

@staticmethod
def default_vpc_endpoint_service(
Expand Down Expand Up @@ -3464,6 +3519,44 @@ def create_model_package(
self.model_packages[model_package.model_package_arn] = model_package
return model_package.model_package_arn

def create_feature_group(
    self,
    feature_group_name: str,
    record_identifier_feature_name: str,
    event_time_feature_name: str,
    feature_definitions: List[Dict[str, str]],
    offline_store_config: Dict[str, Any],
    role_arn: str,
    tags: Any,
) -> str:
    """Create a feature group, register it under its ARN and return that ARN.

    The group is built for this backend's region and account.  Groups are
    indexed by ARN, so an existing entry with the same ARN is replaced.
    """
    new_group = FeatureGroup(
        region_name=self.region_name,
        account_id=self.account_id,
        feature_group_name=feature_group_name,
        record_identifier_feature_name=record_identifier_feature_name,
        event_time_feature_name=event_time_feature_name,
        feature_definitions=feature_definitions,
        offline_store_config=offline_store_config,
        role_arn=role_arn,
        tags=tags,
    )
    arn = new_group.feature_group_arn
    self.feature_groups[arn] = new_group
    return arn

def describe_feature_group(
    self,
    feature_group_name: str,
) -> Dict[str, Any]:
    """Return the description of the feature group with the given name.

    The name is turned into the ARN under which groups are stored (the name
    is lower-cased first).  Raises ``KeyError`` when no group is stored
    under that ARN.
    """
    lookup_arn = arn_formatter(
        region_name=self.region_name,
        account_id=self.account_id,
        _type="feature-group",
        _id=feature_group_name.lower(),
    )
    return self.feature_groups[lookup_arn].describe()


class FakeExperiment(BaseObject):
def __init__(
Expand Down
20 changes: 20 additions & 0 deletions moto/sagemaker/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,23 @@ def create_model_package_group(self) -> str:
tags=tags,
)
return json.dumps(dict(ModelPackageGroupArn=model_package_group_arn))

def create_feature_group(self) -> str:
    """Handle CreateFeatureGroup and return ``{"FeatureGroupArn": ...}`` as JSON."""
    backend = self.sagemaker_backend
    # Backend keyword argument -> request parameter it is read from.
    param_names = (
        ("feature_group_name", "FeatureGroupName"),
        ("record_identifier_feature_name", "RecordIdentifierFeatureName"),
        ("event_time_feature_name", "EventTimeFeatureName"),
        ("feature_definitions", "FeatureDefinitions"),
        ("offline_store_config", "OfflineStoreConfig"),
        ("role_arn", "RoleArn"),
        ("tags", "Tags"),
    )
    backend_kwargs = {
        kwarg: self._get_param(request_key) for kwarg, request_key in param_names
    }
    created_arn = backend.create_feature_group(**backend_kwargs)
    return json.dumps({"FeatureGroupArn": created_arn})

def describe_feature_group(self) -> str:
    """Handle DescribeFeatureGroup and return the backend's description as JSON."""
    backend = self.sagemaker_backend
    requested_name = self._get_param("FeatureGroupName")
    description = backend.describe_feature_group(
        feature_group_name=requested_name,
    )
    return json.dumps(description)
85 changes: 85 additions & 0 deletions tests/test_sagemaker/test_sagemaker_feature_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Unit tests for sagemaker-supported APIs."""
import re
from datetime import datetime

import boto3

from moto import mock_sagemaker

# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html


@mock_sagemaker
def test_create_feature_group():
    """CreateFeatureGroup returns an ARN built from region, account and name."""
    sagemaker = boto3.client("sagemaker", region_name="us-east-2")
    request = {
        "FeatureGroupName": "some-feature-group-name",
        "RecordIdentifierFeatureName": "some_record_identifier",
        "EventTimeFeatureName": "EventTime",
        "FeatureDefinitions": [
            {"FeatureName": "some_feature", "FeatureType": "String"},
            {"FeatureName": "EventTime", "FeatureType": "Fractional"},
            {"FeatureName": "some_record_identifier", "FeatureType": "String"},
        ],
        "RoleArn": "arn:aws:iam::123456789012:role/AWSFeatureStoreAccess",
        "OfflineStoreConfig": {
            "DisableGlueTableCreation": False,
            "S3StorageConfig": {"S3Uri": "s3://mybucket"},
        },
    }

    response = sagemaker.create_feature_group(**request)

    expected_arn = "arn:aws:sagemaker:us-east-2:123456789012:feature-group/some-feature-group-name"
    assert response["FeatureGroupArn"] == expected_arn


@mock_sagemaker
def test_describe_feature_group():
    """DescribeFeatureGroup echoes the creation inputs plus generated fields."""
    sagemaker = boto3.client("sagemaker", region_name="us-east-2")

    group_name = "some-feature-group-name"
    record_id_feature = "some_record_identifier"
    event_time_feature = "EventTime"
    access_role = "arn:aws:iam::123456789012:role/AWSFeatureStoreAccess"
    definitions = [
        {"FeatureName": "some_feature", "FeatureType": "String"},
        {"FeatureName": event_time_feature, "FeatureType": "Fractional"},
        {"FeatureName": record_id_feature, "FeatureType": "String"},
    ]
    offline_store_request = {
        "DisableGlueTableCreation": False,
        "S3StorageConfig": {"S3Uri": "s3://mybucket"},
    }

    sagemaker.create_feature_group(
        FeatureGroupName=group_name,
        RecordIdentifierFeatureName=record_id_feature,
        EventTimeFeatureName=event_time_feature,
        FeatureDefinitions=definitions,
        RoleArn=access_role,
        OfflineStoreConfig=offline_store_request,
    )
    resp = sagemaker.describe_feature_group(FeatureGroupName=group_name)

    # Fields echoed back from the create request.
    assert resp["FeatureGroupName"] == group_name
    expected_arn = "arn:aws:sagemaker:us-east-2:123456789012:feature-group/some-feature-group-name"
    assert resp["FeatureGroupArn"] == expected_arn
    assert resp["RecordIdentifierFeatureName"] == record_id_feature
    assert resp["EventTimeFeatureName"] == event_time_feature
    assert resp["FeatureDefinitions"] == definitions
    assert resp["RoleArn"] == access_role

    # Generated data catalog entry: dashes in the name become underscores and
    # a numeric suffix is appended.
    offline_store = resp["OfflineStoreConfig"]
    data_catalog = offline_store["DataCatalogConfig"]
    assert re.match(r"^some_feature_group_name_[0-9]+$", data_catalog["TableName"])
    assert data_catalog["Catalog"] == "AwsDataCatalog"
    assert data_catalog["Database"] == "sagemaker_featurestore"

    assert offline_store["S3StorageConfig"]["S3Uri"] == "s3://mybucket"
    assert isinstance(resp["CreationTime"], datetime)
    assert resp["FeatureGroupStatus"] == "Created"

0 comments on commit 8d91d09

Please sign in to comment.