Open
Description
Hi
- how can I use shap.force_plot with a current definition of explainer in order to obtain the following figure? The example below illustrates this use but cannot work correctly. How can I correct this example?
- shap.force_plot function have a base value as parameter. This parameter equals to explainer.expected_value
how can I calculate expected_value with current definition of Explainer?
from hiclass import LocalClassifierPerParentNode, Explainer
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import shap
X_train = np.array([
[40.7, 1. , 1. , 2. , 5. , 2. , 1. , 5. , 34.3],
[39.2, 0. , 2. , 4. , 1. , 3. , 1. , 2. , 34.1],
[40.6, 0. , 3. , 1. , 4. , 5. , 0. , 6. , 27.7],
[36.5, 0. , 3. , 1. , 2. , 2. , 0. , 2. , 39.9],
])
Y_train = np.array([
['Gastrointestinal', 'Norovirus', ''],
['Respiratory', 'Covid', ''],
['Allergy', 'External', 'Bee Allergy'],
['Respiratory', 'Cold', ''],
])
test_sample = np.array([[35.5, 0. , 1. , 1. , 3. , 3. , 0. , 2. , 37.5]])
classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier(), replace_classifiers=False)
classifier.fit(X_train, Y_train)
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(test_sample)
mask = {'class': 'Cold', 'level':1}
shap_values = explanations.sel(mask).shap_values
shap.force_plot(shap_values, test_sample)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/IPython/core/formatters.py:344, in BaseFormatter.__call__(self, obj)
342 method = get_real_method(obj, self.print_method)
343 if method is not None:
--> 344 return method()
345 return None
346 else:
File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/shap/plots/_force.py:531, in AdditiveForceVisualizer._repr_html_(self)
530 def _repr_html_(self):
--> 531 return self.html()
File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/shap/plots/_force.py:516, in AdditiveForceVisualizer.html(self, label_margin)
510 self.data["labelMargin"] = label_margin
511 generated_id = id_generator()
512 return f"""
513 <div id='{generated_id}'>{err_msg}</div>
514 <script>
515 if (window.SHAP) SHAP.ReactDom.render(
--> 516 SHAP.React.createElement(SHAP.AdditiveForceVisualizer, {json.dumps(self.data)}),
517 document.getElementById('{generated_id}')
518 );
519 </script>"""
File ~/anaconda3/envs/hiclass/lib/python3.8/json/__init__.py:231, in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
226 # cached encoder
227 if (not skipkeys and ensure_ascii and
228 check_circular and allow_nan and
229 cls is None and indent is None and separators is None and
230 default is None and not sort_keys and not kw):
--> 231 return _default_encoder.encode(obj)
232 if cls is None:
233 cls = JSONEncoder
File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:199, in JSONEncoder.encode(self, o)
195 return encode_basestring(o)
196 # This doesn't pass the iterator directly to ''.join() because the
197 # exceptions aren't as detailed. The list call should be roughly
198 # equivalent to the PySequence_Fast that ''.join() would do.
--> 199 chunks = self.iterencode(o, _one_shot=True)
200 if not isinstance(chunks, (list, tuple)):
201 chunks = list(chunks)
File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:257, in JSONEncoder.iterencode(self, o, _one_shot)
252 else:
253 _iterencode = _make_iterencode(
254 markers, self.default, _encoder, self.indent, floatstr,
255 self.key_separator, self.item_separator, self.sort_keys,
256 self.skipkeys, _one_shot)
--> 257 return _iterencode(o, 0)
File ~/anaconda3/envs/hiclass/lib/python3.8/json/encoder.py:179, in JSONEncoder.default(self, o)
160 def default(self, o):
161 """Implement this method in a subclass such that it returns
162 a serializable object for ``o``, or calls the base implementation
163 (to raise a ``TypeError``).
(...)
177
178 """
--> 179 raise TypeError(f'Object of type {o.__class__.__name__} '
180 f'is not JSON serializable')
TypeError: Object of type DataArray is not JSON serializable
Metadata
Metadata
Assignees
Labels
No labels