How can I use shap.force_plot with the current definition of Explainer to obtain the following figure? The example below illustrates this use but does not work. How can I correct it?
The shap.force_plot function takes a base value as a parameter, which normally equals explainer.expected_value.
How can I calculate expected_value with the current definition of Explainer?
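To make "expected value" concrete: for shap explainers that are given background data, the base value is the average model output over that background set, and SHAP values satisfy local accuracy (base value + sum of SHAP values = model output for the explained sample). The sketch below checks both facts with exact, brute-force interventional Shapley values in plain NumPy. The logistic model is a hypothetical stand-in for one local classifier's probability output, and only the first three features of X_train and test_sample are used. Whether hiclass forwards data=X_train to shap as this background set is an assumption, not something confirmed here; if it does, the expected value for the 'Cold' node would be the mean output of the corresponding local classifier over X_train.

```python
from itertools import combinations
from math import factorial

import numpy as np


def model(X):
    # Hypothetical stand-in for a local classifier's probability output.
    X = np.atleast_2d(X)
    z = 0.3 * (X[:, 0] - 38.0) - 0.5 * X[:, 1] + 0.1 * X[:, 2]
    return 1.0 / (1.0 + np.exp(-z))


def coalition_value(x, background, subset):
    # Interventional value v(S): features in S are fixed to x, the others
    # keep their background values, and the model output is averaged.
    Z = np.array(background, dtype=float)  # copy, so background is untouched
    idx = list(subset)
    if idx:
        Z[:, idx] = x[idx]
    return float(model(Z).mean())


def shapley_values(x, background):
    # Exact Shapley values by enumerating every feature subset; only feasible
    # for a handful of features (TreeExplainer computes this efficiently).
    d = x.shape[0]
    phi = np.zeros(d)
    for i in range(d):
        others = [j for j in range(d) if j != i]
        for size in range(d):
            weight = factorial(size) * factorial(d - size - 1) / factorial(d)
            for S in combinations(others, size):
                phi[i] += weight * (
                    coalition_value(x, background, S + (i,))
                    - coalition_value(x, background, S)
                )
    return phi


# First three columns of X_train and test_sample from the example.
background = np.array([
    [40.7, 1.0, 1.0],
    [39.2, 0.0, 2.0],
    [40.6, 0.0, 3.0],
    [36.5, 0.0, 3.0],
])
x = np.array([35.5, 0.0, 1.0])

# The base value passed to a force plot: mean model output over the background.
expected_value = float(model(background).mean())
phi = shapley_values(x, background)
```

Here expected_value + phi.sum() reproduces model(x), which is exactly the relationship a force plot draws, starting from the base value and pushing toward the prediction.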
from hiclass import LocalClassifierPerParentNode, Explainer
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import shap
X_train = np.array([
[40.7, 1. , 1. , 2. , 5. , 2. , 1. , 5. , 34.3],
[39.2, 0. , 2. , 4. , 1. , 3. , 1. , 2. , 34.1],
[40.6, 0. , 3. , 1. , 4. , 5. , 0. , 6. , 27.7],
[36.5, 0. , 3. , 1. , 2. , 2. , 0. , 2. , 39.9],
])
Y_train = np.array([
['Gastrointestinal', 'Norovirus', ''],
['Respiratory', 'Covid', ''],
['Allergy', 'External', 'Bee Allergy'],
['Respiratory', 'Cold', ''],
])
test_sample = np.array([[35.5, 0. , 1. , 1. , 3. , 3. , 0. , 2. , 37.5]])
classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier(), replace_classifiers=False)
classifier.fit(X_train, Y_train)
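# Explain the fitted hierarchical classifier, with X_train as background data
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(test_sample)
# Select the SHAP values for class 'Cold' at level 1 of the hierarchy
mask = {'class': 'Cold', 'level': 1}
shap_values = explanations.sel(mask).shap_values
# This call fails with the TypeError shown below
shap.force_plot(shap_values, test_sample)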
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/IPython/core/formatters.py:344, in BaseFormatter.__call__(self, obj)
342 method = get_real_method(obj, self.print_method)
343 if method is not None:
--> 344 return method()
345 return None
346 else:
File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/shap/plots/_force.py:531, in AdditiveForceVisualizer._repr_html_(self)
530 def _repr_html_(self):
--> 531 return self.html()
File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/shap/plots/_force.py:516, in AdditiveForceVisualizer.html(self, label_margin)
510 self.data["labelMargin"] = label_margin
511 generated_id = id_generator()
512 return f"""
513 <div id='{generated_id}'>{err_msg}</div>
514 <script>
515 if (window.SHAP) SHAP.ReactDom.render(
--> 516 SHAP.React.createElement(SHAP.AdditiveForceVisualizer, {json.dumps(self.data)}),
517 document.getElementById('{generated_id}')
518 );
519 </script>"""
File ~/anaconda3/envs/hiclass/lib/python3.8/json/__init__.py:231, in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
226 # cached encoder
227 if (not skipkeys and ensure_ascii and
228 check_circular and allow_nan and
229 cls is None and indent is None and separators is None and
230 default is None and not sort_keys and not kw):
--> 231 return _default_encoder.encode(obj)
232 if cls is None:
233 cls = JSONEncoder
File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:199, in JSONEncoder.encode(self, o)
195 return encode_basestring(o)
196 # This doesn't pass the iterator directly to ''.join() because the
197 # exceptions aren't as detailed. The list call should be roughly
198 # equivalent to the PySequence_Fast that ''.join() would do.
--> 199 chunks = self.iterencode(o, _one_shot=True)
200 if not isinstance(chunks, (list, tuple)):
201 chunks = list(chunks)
File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:257, in JSONEncoder.iterencode(self, o, _one_shot)
252 else:
253 _iterencode = _make_iterencode(
254 markers, self.default, _encoder, self.indent, floatstr,
255 self.key_separator, self.item_separator, self.sort_keys,
256 self.skipkeys, _one_shot)
--> 257 return _iterencode(o, 0)
File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:179, in JSONEncoder.default(self, o)
160 def default(self, o):
161 """Implement this method in a subclass such that it returns
162 a serializable object for ``o``, or calls the base implementation
163 (to raise a ``TypeError``).
(...)
177
178 """
--> 179 raise TypeError(f'Object of type {o.__class__.__name__} '
180 f'is not JSON serializable')
TypeError: Object of type DataArray is not JSON serializable
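The traceback suggests the problem is argument order rather than hiclass itself: shap.force_plot takes the base value as its first positional argument (force_plot(base_value, shap_values=None, features=None, ...)), so the call above passes the xarray DataArray as base_value and test_sample as shap_values, and the DataArray then reaches json.dumps when the plot is rendered. A minimal sketch of a fix, assuming the selected DataArray holds one row of per-feature SHAP values for the chosen class and level: convert it to a plain NumPy array and pass an explicit float base value. The helper names are hypothetical, and recovering the base value through additivity assumes you can obtain the model output being explained (for example the predicted probability of 'Cold' from the relevant local classifier).

```python
import numpy as np


def to_plain_row(values):
    # np.asarray accepts NumPy arrays and array-likes such as an xarray
    # DataArray (which implements __array__), dropping the labels that
    # json.dumps cannot serialize.
    row = np.asarray(values, dtype=float)
    # A single explained sample may arrive as shape (1, n_features);
    # a single-sample force plot expects a 1-D row.
    if row.ndim == 2 and row.shape[0] == 1:
        row = row[0]
    if row.ndim != 1:
        raise ValueError(f"expected one sample, got shape {row.shape}")
    return row


def base_value_from_additivity(model_output, shap_row):
    # SHAP local accuracy: model_output = base_value + sum(shap_row),
    # so the base value follows once the explained output is known.
    return float(model_output) - float(np.sum(shap_row))


def force_plot_args(shap_values, features, model_output):
    # Returns (base_value, shap_row, feature_row) ready for shap.force_plot.
    sv = to_plain_row(shap_values)
    x = to_plain_row(features)
    if sv.shape != x.shape:
        raise ValueError(f"{sv.shape} SHAP values vs {x.shape} features")
    return base_value_from_additivity(model_output, sv), sv, x
```

With the example above this would be roughly base, sv, x = force_plot_args(explanations.sel(mask).shap_values, test_sample, p_cold) followed by shap.force_plot(base, sv, x, matplotlib=True), where p_cold is a hypothetical variable holding the explained probability. Whether hiclass exposes that output, or the explainer's expected value directly, is not confirmed here.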
RamSnoussi changed the title from "Use force plot with hiclass's explainer" to "Use force plot with the hiclass's explainer" on Aug 13, 2024.
RamSnoussi changed the title from "Use force plot with the hiclass's explainer" to "Use shap.force_plot with the hiclass's explainer" on Aug 13, 2024.
Hi, how can I calculate expected_value with the current definition of Explainer?