Use shap.force_plot with the hiclass's explainer #130

Open
RamSnoussi opened this issue Aug 13, 2024 · 1 comment

@RamSnoussi

Hi,

  1. How can I use shap.force_plot with the current definition of the Explainer to obtain the figure below? The example below illustrates what I am trying to do, but it does not work. How can I fix it?
  2. shap.force_plot takes a base value as a parameter, which is normally explainer.expected_value. How can I calculate expected_value with the current definition of Explainer?

[Screenshot from 2024-08-13 16-14-35: the desired force plot]

from hiclass import LocalClassifierPerParentNode, Explainer
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import shap
X_train = np.array([
    [40.7,  1. ,  1. ,  2. ,  5. ,  2. ,  1. ,  5. , 34.3],
    [39.2,  0. ,  2. ,  4. ,  1. ,  3. ,  1. ,  2. , 34.1],
    [40.6,  0. ,  3. ,  1. ,  4. ,  5. ,  0. ,  6. , 27.7],
    [36.5,  0. ,  3. ,  1. ,  2. ,  2. ,  0. ,  2. , 39.9],
])

Y_train = np.array([
    ['Gastrointestinal', 'Norovirus', ''],
    ['Respiratory', 'Covid', ''],
    ['Allergy', 'External', 'Bee Allergy'],
    ['Respiratory', 'Cold', ''],
])

test_sample = np.array([[35.5,  0. ,  1. ,  1. ,  3. ,  3. ,  0. ,  2. , 37.5]])

classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier(), replace_classifiers=False)
classifier.fit(X_train, Y_train)
explainer = Explainer(classifier, data=X_train, mode="tree")
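explanations = explainer.explain(test_sample)
# Select the SHAP values computed for class 'Cold' at hierarchy level 1
mask = {'class': 'Cold', 'level': 1}
shap_values = explanations.sel(mask).shap_values
# This call fails with the TypeError below
shap.force_plot(shap_values, test_sample)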


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/IPython/core/formatters.py:344, in BaseFormatter.__call__(self, obj)
    342     method = get_real_method(obj, self.print_method)
    343     if method is not None:
--> 344         return method()
    345     return None
    346 else:

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/shap/plots/_force.py:531, in AdditiveForceVisualizer._repr_html_(self)
    530 def _repr_html_(self):
--> 531     return self.html()

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/shap/plots/_force.py:516, in AdditiveForceVisualizer.html(self, label_margin)
    510         self.data["labelMargin"] = label_margin
    511         generated_id = id_generator()
    512         return f"""
    513 <div id='{generated_id}'>{err_msg}</div>
    514  <script>
    515    if (window.SHAP) SHAP.ReactDom.render(
--> 516     SHAP.React.createElement(SHAP.AdditiveForceVisualizer, {json.dumps(self.data)}),
    517     document.getElementById('{generated_id}')
    518   );
    519 </script>"""

File ~/anaconda3/envs/hiclass/lib/python3.8/json/__init__.py:231, in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
    226 # cached encoder
    227 if (not skipkeys and ensure_ascii and
    228     check_circular and allow_nan and
    229     cls is None and indent is None and separators is None and
    230     default is None and not sort_keys and not kw):
--> 231     return _default_encoder.encode(obj)
    232 if cls is None:
    233     cls = JSONEncoder

File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:199, in JSONEncoder.encode(self, o)
    195         return encode_basestring(o)
    196 # This doesn't pass the iterator directly to ''.join() because the
    197 # exceptions aren't as detailed.  The list call should be roughly
    198 # equivalent to the PySequence_Fast that ''.join() would do.
--> 199 chunks = self.iterencode(o, _one_shot=True)
    200 if not isinstance(chunks, (list, tuple)):
    201     chunks = list(chunks)

File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:257, in JSONEncoder.iterencode(self, o, _one_shot)
    252 else:
    253     _iterencode = _make_iterencode(
    254         markers, self.default, _encoder, self.indent, floatstr,
    255         self.key_separator, self.item_separator, self.sort_keys,
    256         self.skipkeys, _one_shot)
--> 257 return _iterencode(o, 0)

File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:179, in JSONEncoder.default(self, o)
    160 def default(self, o):
    161     """Implement this method in a subclass such that it returns
    162     a serializable object for ``o``, or calls the base implementation
    163     (to raise a ``TypeError``).
   (...)
    177 
    178     """
--> 179     raise TypeError(f'Object of type {o.__class__.__name__} '
    180                     f'is not JSON serializable')

TypeError: Object of type DataArray is not JSON serializable
@RamSnoussi RamSnoussi changed the title Use force plot with hiclass's explainer Use force plot with the hiclass's explainer Aug 13, 2024
@RamSnoussi RamSnoussi changed the title Use force plot with the hiclass's explainer Use shap.force_plot with the hiclass's explainer Aug 13, 2024
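The traceback points at the last line of the example. shap.force_plot's first positional parameter is base_value (its signature starts force_plot(base_value, shap_values=None, features=None, ...)), so the selected SHAP values are passed as the base value and test_sample as the SHAP values. The xarray DataArray then most likely ends up inside the data the force-plot visualizer serializes to JSON, which produces "Object of type DataArray is not JSON serializable". A minimal sketch of the conversion step follows, assuming the selection covers a single sample for one class and level; the helper name to_force_plot_vector is hypothetical, and the code relies only on NumPy and on xarray objects exposing their data through .values.

```python
import numpy as np


def to_force_plot_vector(values):
    """Convert one sample's SHAP values into a flat NumPy float vector.

    Hypothetical helper: xarray DataArrays expose their data via `.values`;
    NumPy arrays and nested lists have no such attribute and are used as-is.
    """
    raw = getattr(values, "values", values)
    # Drop singleton dimensions such as (1 sample, n_features) -> (n_features,)
    vector = np.atleast_1d(np.squeeze(np.asarray(raw, dtype=float)))
    if vector.ndim != 1:
        raise ValueError(f"expected SHAP values for one sample, got shape {vector.shape}")
    return vector


# Placeholder standing in for explanations.sel(mask).shap_values (1 sample, 9 features)
selected = np.zeros((1, 9))
phi = to_force_plot_vector(selected)
print(phi.shape)  # (9,)
```

With hiclass, the call would then take the form shap.force_plot(base_value, to_force_plot_vector(shap_values), test_sample[0], matplotlib=True), where base_value is the expected model output for the 'Cold' class and still has to be supplied separately; matplotlib=True is only needed outside a notebook.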
@RamSnoussi
Copy link
Author

Hi @mirand863,
Do you have any suggestions on how I can use shap.force_plot with hiclass? Thanks!
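On the second question: SHAP explanations are additive, so for one sample the explained model output f(x) equals the base value plus the sum of the feature attributions, f(x) = base_value + sum_i phi_i. Assuming the hiclass explanation for class 'Cold' at level 1 explains the probability that the corresponding local classifier (the one at the 'Respiratory' parent node) assigns to 'Cold', the base value can be recovered as that probability minus the sum of the SHAP values. The sketch below shows only this arithmetic; base_value_from_additivity is a hypothetical name, and reading the probability out of the fitted hiclass model is left open.

```python
import numpy as np


def base_value_from_additivity(model_output, shap_values):
    """Recover the SHAP base value from local accuracy: f(x) = base + sum(phi).

    Hypothetical helper. `model_output` is the explained output for one sample
    and one class (e.g. a predicted probability); `shap_values` holds that
    sample's per-feature SHAP values for the same class.
    """
    phi = np.asarray(shap_values, dtype=float).reshape(-1)  # flatten one sample's values
    return float(model_output) - float(phi.sum())


# Toy check with hand-picked numbers (not real hiclass output)
print(base_value_from_additivity(0.75, [0.25, -0.125, 0.125]))  # 0.5
```

As a rough cross-check, if the underlying explainer is shap.TreeExplainer with background data, its expected value is the mean model output over that background data, so the recovered number should be close to the local classifier's average predicted probability of 'Cold' over the background samples.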
