Skip to content

Commit

Permalink
[PNE-6367] Add support for feature effects.
Browse files Browse the repository at this point in the history
The payload comes in as a condensed format, so it's expanded in order to
construct nested lists of objects for clarity and ease of use.

Additionally, the hierarchy of data is flipped to more closely match how
it will be used by our customers. To that end, 'as_dict' is provided to
ease importing it into a pandas DataFrame for whatever processing and
analysis the customer desires.
  • Loading branch information
anoto-moniz committed Nov 22, 2024
1 parent 455e12e commit 8a16228
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.11.6"
__version__ = "3.12.0"
75 changes: 75 additions & 0 deletions src/citrine/informatics/feature_effects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Dict
from uuid import UUID

from citrine._rest.resource import Resource
from citrine._serialization import properties


class ShapleyMaterial(Resource):
    """The feature effect of a single material.

    Pairs a material's identifier with the numeric effect value reported for it.
    Both fields are declared with ``serializable=False``.
    """

    # Identifier of the material this value belongs to.
    material_id = properties.UUID(
        'material_id',
        serializable=False,
    )
    # The effect value reported for that material.
    value = properties.Float(
        'value',
        serializable=False,
    )


class ShapleyFeature(Resource):
    """The effects of one feature, listed per material."""

    # Name of the feature these effects describe.
    feature = properties.String('feature', serializable=False)
    # One ShapleyMaterial entry per material.
    materials = properties.List(
        properties.Object(ShapleyMaterial),
        'materials',
        serializable=False,
    )

    @property
    def material_dict(self) -> Dict[UUID, float]:
        """Return this feature's effects keyed by material ID.

        If the same material ID appears more than once, the last entry wins.
        """
        pairs = ((entry.material_id, entry.value) for entry in self.materials)
        return dict(pairs)


class ShapleyOutput(Resource):
    """The effects on one output, listed per feature."""

    # Name of the output these feature effects apply to.
    output = properties.String('output', serializable=False)
    # One ShapleyFeature entry per feature.
    features = properties.List(
        properties.Object(ShapleyFeature),
        'features',
        serializable=False,
    )

    @property
    def feature_dict(self) -> Dict[str, Dict[UUID, float]]:
        """Return this output's effects as ``{feature name: {material ID: value}}``.

        If the same feature name appears more than once, the last entry wins.
        """
        pairs = ((entry.feature, entry.material_dict) for entry in self.features)
        return dict(pairs)


class FeatureEffects(Resource):
    """Captures information about the feature effects associated with a predictor.

    The raw payload carries results in a condensed layout: one shared list of
    material IDs plus, for every output and feature, a list of values that lines
    up positionally with that list. ``_pre_build`` expands this into nested
    output -> feature -> material records, which are exposed via ``outputs``.
    """

    predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
    predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
    status = properties.String('metadata.status', serializable=False)
    failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason',
                                         serializable=False)

    # Read from the 'resultobj' key that _pre_build adds, not from the raw 'result'.
    outputs = properties.List(properties.Object(ShapleyOutput), 'resultobj', serializable=False)

    @classmethod
    def _pre_build(cls, data: dict) -> Dict:
        """Expand the condensed ``result`` payload into ``data["resultobj"]``.

        Presumably invoked by ``Resource.build`` before the declared properties
        are populated -- confirm against the base class.

        Parameters
        ----------
        data: dict
            The raw response. ``data["result"]`` is expected to hold ``materials``
            (a list of material IDs) and ``outputs`` (output name -> feature name
            -> list of values aligned with ``materials``).

        Returns
        -------
        dict
            The same ``data`` object, mutated in place to include ``resultobj``.

        Notes
        -----
        A payload whose ``result`` (or one of its parts) is missing or null -- as
        may be the case when the computation did not succeed -- yields an empty
        ``resultobj`` instead of raising, so ``status`` and ``failure_reason``
        remain accessible on the built object.

        """
        shapley = data.get("result") or {}
        material_ids = shapley.get("materials") or []
        condensed_outputs = shapley.get("outputs") or {}

        outputs = []
        for output, feature_dict in condensed_outputs.items():
            features = [
                {
                    "feature": feature,
                    # Values are positionally aligned with the shared material ID list.
                    "materials": [
                        {"material_id": mid, "value": value}
                        for mid, value in zip(material_ids, values)
                    ],
                }
                for feature, values in feature_dict.items()
            ]
            outputs.append({"output": output, "features": features})

        data["resultobj"] = outputs
        return data

    @property
    def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]:
        """Presents the feature effects as a dictionary by output.

        The shape is ``{output name: {feature name: {material ID: value}}}``.
        """
        return {output.output: output.feature_dict for output in self.outputs}
11 changes: 10 additions & 1 deletion src/citrine/informatics/predictors/graph_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from citrine._session import Session
from citrine._utils.functions import format_escaped_url
from citrine.informatics.data_sources import DataSource
from citrine.informatics.feature_effects import FeatureEffects
from citrine.informatics.predictors.single_predict_request import SinglePredictRequest
from citrine.informatics.predictors.single_prediction import SinglePrediction
from citrine.informatics.predictors import PredictorNode, Predictor
from citrine.informatics.reports import Report
from citrine.resources.report import ReportResource

__all__ = ['GraphPredictor']
Expand Down Expand Up @@ -104,7 +106,7 @@ def wrap_instance(predictor_data: dict) -> dict:
}

@property
def report(self):
def report(self) -> Report:
"""Fetch the predictor report."""
if self.uid is None or self._session is None or self._project_id is None \
or getattr(self, "version", None) is None:
Expand All @@ -113,6 +115,13 @@ def report(self):
report_resource = ReportResource(self._project_id, self._session)
return report_resource.get(predictor_id=self.uid, predictor_version=self.version)

@property
def feature_effects(self) -> FeatureEffects:
    """Retrieve the feature effects for all outputs in the predictor's training data.

    Each access issues a POST with an empty body to this predictor's
    ``/shapley/query`` path and builds a FeatureEffects object from the response.
    """
    query_path = self._path() + '/shapley/query'
    payload = self._session.post_resource(query_path, {}, version=self._api_version)
    return FeatureEffects.build(payload)

def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction:
"""Make a one-off prediction with this predictor."""
path = self._path() + '/predict'
Expand Down
1 change: 1 addition & 0 deletions src/citrine/resources/table_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class TableConfig(Resource["TableConfig"]):
The query used to define the materials underpinning this table
generation_algorithm: TableFromGemdQueryAlgorithm
Which algorithm was used to generate the config based on the GemdQuery results
"""

# FIXME (DML): rename this (this is dependent on the server side)
Expand Down
27 changes: 25 additions & 2 deletions tests/informatics/test_predictors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Tests for citrine.informatics.predictors."""
import uuid
import pytest
import mock
import pytest
import uuid
from random import random

from citrine.informatics.data_sources import GemTableDataSource
from citrine.informatics.descriptors import RealDescriptor, IntegerDescriptor, \
Expand All @@ -12,6 +13,10 @@
from citrine.informatics.predictors.single_prediction import SinglePrediction
from citrine.informatics.design_candidate import DesignMaterial

from tests.utils.factories import FeatureEffectsResponseFactory
from tests.utils.session import FakeCall, FakeSession


w = IntegerDescriptor("w", lower_bound=0, upper_bound=100)
x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="")
Expand Down Expand Up @@ -485,3 +490,21 @@ def test_single_predict(graph_predictor):
prediction_out = graph_predictor.predict(request)
assert prediction_out.dump() == prediction_in.dump()
assert session.post_resource.call_count == 1


def test_feature_effects(graph_predictor):
    """Check that `feature_effects` POSTs to the shapley query path and expands the result."""
    response = FeatureEffectsResponseFactory()
    # The factory attaches the expected expanded form under a private key; remove it
    # so that only the raw payload is handed to the fake session.
    expected = response.pop("_result_as_dict")

    fake_session = FakeSession()
    fake_session.set_response(response)
    graph_predictor._session = fake_session
    graph_predictor._project_id = uuid.uuid4()

    effects = graph_predictor.feature_effects

    expected_path = (
        f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}"
        f"/versions/{graph_predictor.version}/shapley/query"
    )
    assert fake_session.calls == [FakeCall(method='POST', path=expected_path, json={})]
    assert effects.as_dict == expected
40 changes: 40 additions & 0 deletions tests/utils/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,3 +859,43 @@ class AnalysisWorkflowEntityDataFactory(factory.DictFactory):
id = factory.Faker('uuid4')
data = factory.SubFactory(AnalysisWorkflowDataDataFactory)
metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory)


def _fake_material_id():
    """Declare one fake material ID (``cast_to=None`` leaves the UUID uncast)."""
    return factory.Faker('uuid4', cast_to=None)


def _fake_effect_values():
    """Declare a list of three fake floats, one per material in the result factory."""
    return factory.List([
        factory.Faker("pyfloat"),
        factory.Faker("pyfloat"),
        factory.Faker("pyfloat"),
    ])


class FeatureEffectsResponseResultFactory(factory.DictFactory):
    """Builds a condensed feature effects result: three materials and two outputs."""

    materials = factory.List([
        _fake_material_id(),
        _fake_material_id(),
        _fake_material_id(),
    ])
    # Each feature's list holds one value per entry in `materials`.
    outputs = factory.Dict({
        "output1": factory.Dict({
            "feature1": _fake_effect_values(),
        }),
        "output2": factory.Dict({
            "feature1": _fake_effect_values(),
            "feature2": _fake_effect_values(),
        }),
    })

class FeatureEffectsMetadataFactory(factory.DictFactory):
    """Builds the metadata section of a feature effects response, with a fixed status."""

    predictor_id = factory.Faker(
        'uuid4',
    )
    predictor_version = factory.Faker(
        'random_digit_not_null',
    )
    created = factory.SubFactory(
        UserTimestampDataFactory,
    )
    updated = factory.SubFactory(
        UserTimestampDataFactory,
    )
    status = 'SUCCEEDED'


class FeatureEffectsResponseFactory(factory.DictFactory):
    """Builds a full feature effects response, plus its expected expanded form.

    The extra ``_result_as_dict`` entry holds the result expanded by
    ``_expand_condensed``; it is included in the built dictionary alongside the
    other keys.
    """

    # Presently, querying from the SDK is not allowed.
    query = {}
    metadata = factory.SubFactory(FeatureEffectsMetadataFactory)
    result = factory.SubFactory(FeatureEffectsResponseResultFactory)
    _result_as_dict = factory.LazyAttribute(
        lambda response: _expand_condensed(response.result)
    )


def _expand_condensed(result_obj):
whole_dict = {}
for output, feature_dict in result_obj["outputs"].items():
whole_dict[output] = {}
for feature, values in feature_dict.items():
whole_dict[output][feature] = dict(zip(result_obj["materials"], values))
return whole_dict

0 comments on commit 8a16228

Please sign in to comment.