from typing import Dict
from uuid import UUID

from citrine._rest.resource import Resource
from citrine._serialization import properties


class ShapleyMaterial(Resource):
    """The feature effect of a material."""

    # Identifier of the material this effect belongs to.
    material_id = properties.UUID('material_id', serializable=False)
    # The effect value reported for that material.
    value = properties.Float('value', serializable=False)


class ShapleyFeature(Resource):
    """All feature effects for this feature by material."""

    feature = properties.String('feature', serializable=False)
    materials = properties.List(properties.Object(ShapleyMaterial), 'materials',
                                serializable=False)

    @property
    def material_dict(self) -> Dict[UUID, float]:
        """Presents the feature's effects as a dictionary by material."""
        return {material.material_id: material.value for material in self.materials}


class ShapleyOutput(Resource):
    """All feature effects for this output by feature."""

    output = properties.String('output', serializable=False)
    features = properties.List(properties.Object(ShapleyFeature), 'features', serializable=False)

    @property
    def feature_dict(self) -> Dict[str, Dict[UUID, float]]:
        """Presents the output's feature effects as a dictionary by feature."""
        return {feature.feature: feature.material_dict for feature in self.features}


class FeatureEffects(Resource):
    """Captures information about the feature effects associated with a predictor."""

    predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
    predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
    status = properties.String('metadata.status', serializable=False)
    failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason',
                                         serializable=False)

    # Populated from the 'resultobj' key that _pre_build derives from 'result'.
    outputs = properties.Optional(properties.List(properties.Object(ShapleyOutput)), 'resultobj',
                                  serializable=False)

    @classmethod
    def _pre_build(cls, data: dict) -> Dict:
        """Reshape the condensed ``result`` payload into nested ``resultobj`` records.

        The condensed form holds one shared list of material ids plus, for each
        output and feature, a list of values in the same order. Each value is
        paired with its material id and stored under ``data["resultobj"]`` as
        ``[{"output": ..., "features": [{"feature": ..., "materials": [...]}]}]``.

        Parameters
        ----------
        data: dict
            The raw response dictionary. It is modified in place.

        Returns
        -------
        dict
            The same dictionary, with ``resultobj`` added when a result was present.

        """
        shapley = data.get("result")
        if not shapley:
            # This class models a failure_reason and an optional `outputs`, so a
            # response without a result must not blow up with a KeyError here.
            return data

        material_ids = shapley["materials"]

        outputs = []
        for output, feature_dict in shapley["outputs"].items():
            features = []
            for feature, values in feature_dict.items():
                # Values are positional: the i-th value belongs to the i-th material id.
                materials = [{"material_id": mid, "value": value}
                             for mid, value in zip(material_ids, values)]
                features.append({"feature": feature, "materials": materials})

            outputs.append({"output": output, "features": features})

        data["resultobj"] = outputs
        return data

    @property
    def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]:
        """Presents the feature effects as a dictionary by output.

        Returns an empty dictionary when there are no outputs.
        """
        if not self.outputs:
            return {}
        return {output.output: output.feature_dict for output in self.outputs}
@property
def feature_effects(self) -> FeatureEffects:
    """Retrieve the feature effects for all outputs in the predictor's training data.

    Posts an empty query to this predictor's ``/shapley/query`` path and builds
    a FeatureEffects object from the response.
    """
    return FeatureEffects.build(
        self._session.post_resource(
            self._path() + '/shapley/query',
            {},
            version=self._api_version,
        )
    )
def test_feature_effects(graph_predictor):
    """Feature effects are fetched with a single POST and expanded into nested dicts."""
    payload = FeatureEffectsResponseFactory()
    # The factory stores the expected expansion under a private key; take it out
    # so that only the remaining keys are served as the fake response.
    expected = payload.pop("_result_as_dict")

    fake_session = FakeSession()
    fake_session.set_response(payload)

    graph_predictor._session = fake_session
    graph_predictor._project_id = uuid.uuid4()

    effects = graph_predictor.feature_effects

    shapley_path = (
        f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}"
        f"/versions/{graph_predictor.version}/shapley/query"
    )
    assert fake_session.calls == [FakeCall(method='POST', path=shapley_path, json={})]
    assert effects.as_dict == expected
class FeatureEffectsResponseResultFactory(factory.DictFactory):
    """Condensed result: three material ids and, per output and feature, three floats."""

    materials = factory.List([factory.Faker('uuid4', cast_to=None) for _ in range(3)])
    outputs = factory.Dict({
        "output1": factory.Dict({
            "feature1": factory.List([factory.Faker("pyfloat") for _ in range(3)]),
        }),
        "output2": factory.Dict({
            "feature1": factory.List([factory.Faker("pyfloat") for _ in range(3)]),
            "feature2": factory.List([factory.Faker("pyfloat") for _ in range(3)]),
        }),
    })


class FeatureEffectsMetadataFactory(factory.DictFactory):
    """Metadata section of a feature effects response, with a SUCCEEDED status."""

    predictor_id = factory.Faker('uuid4')
    predictor_version = factory.Faker('random_digit_not_null')
    created = factory.SubFactory(UserTimestampDataFactory)
    updated = factory.SubFactory(UserTimestampDataFactory)
    status = 'SUCCEEDED'


class FeatureEffectsResponseFactory(factory.DictFactory):
    """Full feature effects response, plus the expected expansion of its result."""

    query = {}  # Presently, querying from the SDK is not allowed.
    metadata = factory.SubFactory(FeatureEffectsMetadataFactory)
    result = factory.SubFactory(FeatureEffectsResponseResultFactory)
    # Not part of the response itself: tests pop this key and compare against it.
    _result_as_dict = factory.LazyAttribute(lambda obj: _expand_condensed(obj.result))


def _expand_condensed(result_obj):
    """Expand a condensed result into ``{output: {feature: {material_id: value}}}``.

    Each feature's value list is zipped against ``result_obj["materials"]``.
    """
    return {
        output: {
            feature: dict(zip(result_obj["materials"], values))
            for feature, values in by_feature.items()
        }
        for output, by_feature in result_obj["outputs"].items()
    }