Skip to content

Commit

Permalink
Merge pull request #982 from CitrineInformatics/feature/pne-6367-feat…
Browse files Browse the repository at this point in the history
…ure-effects

[PNE-6367] Add support for feature effects.
  • Loading branch information
anoto-moniz authored Dec 11, 2024
2 parents 455e12e + 623a307 commit 556ce5f
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.11.6"
__version__ = "3.12.0"
76 changes: 76 additions & 0 deletions src/citrine/informatics/feature_effects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Dict
from uuid import UUID

from citrine._rest.resource import Resource
from citrine._serialization import properties


class ShapleyMaterial(Resource):
"""The feature effect of a material."""

material_id = properties.UUID('material_id', serializable=False)
value = properties.Float('value', serializable=False)


class ShapleyFeature(Resource):
"""All feature effects for this feature by material."""

feature = properties.String('feature', serializable=False)
materials = properties.List(properties.Object(ShapleyMaterial), 'materials',
serializable=False)

@property
def material_dict(self) -> Dict[UUID, float]:
"""Presents the feature's effects as a dictionary by material."""
return {material.material_id: material.value for material in self.materials}


class ShapleyOutput(Resource):
"""All feature effects for this output by feature."""

output = properties.String('output', serializable=False)
features = properties.List(properties.Object(ShapleyFeature), 'features', serializable=False)

@property
def feature_dict(self) -> Dict[str, Dict[UUID, float]]:
"""Presents the output's feature effects as a dictionary by feature."""
return {feature.feature: feature.material_dict for feature in self.features}


class FeatureEffects(Resource):
"""Captures information about the feature effects associated with a predictor."""

predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
status = properties.String('metadata.status', serializable=False)
failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason',
serializable=False)

outputs = properties.Optional(properties.List(properties.Object(ShapleyOutput)), 'resultobj',
serializable=False)

@classmethod
def _pre_build(cls, data: dict) -> Dict:
shapley = data["result"]
material_ids = shapley["materials"]

outputs = []
for output, feature_dict in shapley["outputs"].items():
features = []
for feature, values in feature_dict.items():
items = zip(material_ids, values)
materials = [{"material_id": mid, "value": value} for mid, value in items]
features.append({
"feature": feature,
"materials": materials
})

outputs.append({"output": output, "features": features})

data["resultobj"] = outputs
return data

@property
def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]:
"""Presents the feature effects as a dictionary by output."""
return {output.output: output.feature_dict for output in self.outputs}
11 changes: 10 additions & 1 deletion src/citrine/informatics/predictors/graph_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from citrine._session import Session
from citrine._utils.functions import format_escaped_url
from citrine.informatics.data_sources import DataSource
from citrine.informatics.feature_effects import FeatureEffects
from citrine.informatics.predictors.single_predict_request import SinglePredictRequest
from citrine.informatics.predictors.single_prediction import SinglePrediction
from citrine.informatics.predictors import PredictorNode, Predictor
from citrine.informatics.reports import Report
from citrine.resources.report import ReportResource

__all__ = ['GraphPredictor']
Expand Down Expand Up @@ -104,7 +106,7 @@ def wrap_instance(predictor_data: dict) -> dict:
}

@property
def report(self):
def report(self) -> Report:
"""Fetch the predictor report."""
if self.uid is None or self._session is None or self._project_id is None \
or getattr(self, "version", None) is None:
Expand All @@ -113,6 +115,13 @@ def report(self):
report_resource = ReportResource(self._project_id, self._session)
return report_resource.get(predictor_id=self.uid, predictor_version=self.version)

@property
def feature_effects(self) -> FeatureEffects:
"""Retrieve the feature effects for all outputs in the predictor's training data.."""
path = self._path() + '/shapley/query'
response = self._session.post_resource(path, {}, version=self._api_version)
return FeatureEffects.build(response)

def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction:
"""Make a one-off prediction with this predictor."""
path = self._path() + '/predict'
Expand Down
1 change: 1 addition & 0 deletions src/citrine/resources/table_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class TableConfig(Resource["TableConfig"]):
The query used to define the materials underpinning this table
generation_algorithm: TableFromGemdQueryAlgorithm
Which algorithm was used to generate the config based on the GemdQuery results
"""

# FIXME (DML): rename this (this is dependent on the server side)
Expand Down
27 changes: 25 additions & 2 deletions tests/informatics/test_predictors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Tests for citrine.informatics.predictors."""
import uuid
import pytest
import mock
import pytest
import uuid
from random import random

from citrine.informatics.data_sources import GemTableDataSource
from citrine.informatics.descriptors import RealDescriptor, IntegerDescriptor, \
Expand All @@ -12,6 +13,10 @@
from citrine.informatics.predictors.single_prediction import SinglePrediction
from citrine.informatics.design_candidate import DesignMaterial

from tests.utils.factories import FeatureEffectsResponseFactory
from tests.utils.session import FakeCall, FakeSession


w = IntegerDescriptor("w", lower_bound=0, upper_bound=100)
x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="")
Expand Down Expand Up @@ -485,3 +490,21 @@ def test_single_predict(graph_predictor):
prediction_out = graph_predictor.predict(request)
assert prediction_out.dump() == prediction_in.dump()
assert session.post_resource.call_count == 1


def test_feature_effects(graph_predictor):
feature_effects_response = FeatureEffectsResponseFactory()
feature_effects_as_dict = feature_effects_response.pop("_result_as_dict")

session = FakeSession()
session.set_response(feature_effects_response)

graph_predictor._session = session
graph_predictor._project_id = uuid.uuid4()

fe = graph_predictor.feature_effects

expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \
f"/versions/{graph_predictor.version}/shapley/query"
assert session.calls == [FakeCall(method='POST', path=expected_path, json={})]
assert fe.as_dict == feature_effects_as_dict
40 changes: 40 additions & 0 deletions tests/utils/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,3 +859,43 @@ class AnalysisWorkflowEntityDataFactory(factory.DictFactory):
id = factory.Faker('uuid4')
data = factory.SubFactory(AnalysisWorkflowDataDataFactory)
metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory)


class FeatureEffectsResponseResultFactory(factory.DictFactory):
materials = factory.List([
factory.Faker('uuid4', cast_to=None),
factory.Faker('uuid4', cast_to=None),
factory.Faker('uuid4', cast_to=None)
])
outputs = factory.Dict({
"output1": factory.Dict({
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
}),
"output2": factory.Dict({
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]),
"feature2": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
})
})

class FeatureEffectsMetadataFactory(factory.DictFactory):
predictor_id = factory.Faker('uuid4')
predictor_version = factory.Faker('random_digit_not_null')
created = factory.SubFactory(UserTimestampDataFactory)
updated = factory.SubFactory(UserTimestampDataFactory)
status = 'SUCCEEDED'


class FeatureEffectsResponseFactory(factory.DictFactory):
query = {} # Presently, querying from the SDK is not allowed.
metadata = factory.SubFactory(FeatureEffectsMetadataFactory)
result = factory.SubFactory(FeatureEffectsResponseResultFactory)
_result_as_dict = factory.LazyAttribute(lambda obj: _expand_condensed(obj.result))


def _expand_condensed(result_obj):
whole_dict = {}
for output, feature_dict in result_obj["outputs"].items():
whole_dict[output] = {}
for feature, values in feature_dict.items():
whole_dict[output][feature] = dict(zip(result_obj["materials"], values))
return whole_dict

0 comments on commit 556ce5f

Please sign in to comment.