Skip to content

Use shap.force_plot with the hiclass's explainer #130

Open
@RamSnoussi

Description

@RamSnoussi

Hi

  1. how can I use shap.force_plot with a current definition of explainer in order to obtain the following figure? The example below illustrates this use but cannot work correctly. How can I correct this example?
  2. shap.force_plot function have a base value as parameter. This parameter equals to explainer.expected_value
    how can I calculate expected_value with current definition of Explainer?

Capture d’écran du 2024-08-13 16-14-35

from hiclass import LocalClassifierPerParentNode, Explainer
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import shap
X_train = np.array([
    [40.7,  1. ,  1. ,  2. ,  5. ,  2. ,  1. ,  5. , 34.3],
    [39.2,  0. ,  2. ,  4. ,  1. ,  3. ,  1. ,  2. , 34.1],
    [40.6,  0. ,  3. ,  1. ,  4. ,  5. ,  0. ,  6. , 27.7],
    [36.5,  0. ,  3. ,  1. ,  2. ,  2. ,  0. ,  2. , 39.9],
])

Y_train = np.array([
    ['Gastrointestinal', 'Norovirus', ''],
    ['Respiratory', 'Covid', ''],
    ['Allergy', 'External', 'Bee Allergy'],
    ['Respiratory', 'Cold', ''],
])

test_sample = np.array([[35.5,  0. ,  1. ,  1. ,  3. ,  3. ,  0. ,  2. , 37.5]])

classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier(), replace_classifiers=False)
classifier.fit(X_train, Y_train)
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(test_sample)
mask = {'class': 'Cold', 'level':1}
shap_values = explanations.sel(mask).shap_values
shap.force_plot(shap_values, test_sample)


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/IPython/core/formatters.py:344, in BaseFormatter.__call__(self, obj)
    342     method = get_real_method(obj, self.print_method)
    343     if method is not None:
--> 344         return method()
    345     return None
    346 else:

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/shap/plots/_force.py:531, in AdditiveForceVisualizer._repr_html_(self)
    530 def _repr_html_(self):
--> 531     return self.html()

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/shap/plots/_force.py:516, in AdditiveForceVisualizer.html(self, label_margin)
    510         self.data["labelMargin"] = label_margin
    511         generated_id = id_generator()
    512         return f"""
    513 <div id='{generated_id}'>{err_msg}</div>
    514  <script>
    515    if (window.SHAP) SHAP.ReactDom.render(
--> 516     SHAP.React.createElement(SHAP.AdditiveForceVisualizer, {json.dumps(self.data)}),
    517     document.getElementById('{generated_id}')
    518   );
    519 </script>"""

File ~/anaconda3/envs/hiclass/lib/python3.8/json/__init__.py:231, in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
    226 # cached encoder
    227 if (not skipkeys and ensure_ascii and
    228     check_circular and allow_nan and
    229     cls is None and indent is None and separators is None and
    230     default is None and not sort_keys and not kw):
--> 231     return _default_encoder.encode(obj)
    232 if cls is None:
    233     cls = JSONEncoder

File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:199, in JSONEncoder.encode(self, o)
    195         return encode_basestring(o)
    196 # This doesn't pass the iterator directly to ''.join() because the
    197 # exceptions aren't as detailed.  The list call should be roughly
    198 # equivalent to the PySequence_Fast that ''.join() would do.
--> 199 chunks = self.iterencode(o, _one_shot=True)
    200 if not isinstance(chunks, (list, tuple)):
    201     chunks = list(chunks)

File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:257, in JSONEncoder.iterencode(self, o, _one_shot)
    252 else:
    253     _iterencode = _make_iterencode(
    254         markers, self.default, _encoder, self.indent, floatstr,
    255         self.key_separator, self.item_separator, self.sort_keys,
    256         self.skipkeys, _one_shot)
--> 257 return _iterencode(o, 0)

File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:179, in JSONEncoder.default(self, o)
    160 def default(self, o):
    161     """Implement this method in a subclass such that it returns
    162     a serializable object for ``o``, or calls the base implementation
    163     (to raise a ``TypeError``).
   (...)
    177 
    178     """
--> 179     raise TypeError(f'Object of type {o.__class__.__name__} '
    180                     f'is not JSON serializable')

TypeError: Object of type DataArray is not JSON serializable

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions