diff --git a/docs/examples/plot_lcppn_explainer.py b/docs/examples/plot_lcppn_explainer.py index d1290a46..fd048c3c 100644 --- a/docs/examples/plot_lcppn_explainer.py +++ b/docs/examples/plot_lcppn_explainer.py @@ -33,12 +33,11 @@ explanations = explainer.explain(X_test.values) print(explanations) -# Filter samples which only predicted "Respiratory" -respiratory_idx = explainer.get_sample_indices(predictions, "Respiratory") - # Use .sel() method to apply the filter and obtain filtered results shap_val_respiratory = explainer.filter_by_class( - explanations, class_name="Respiratory", sample_indices=respiratory_idx + explanations, + class_name="Respiratory", + sample_indices=explainer.get_sample_indices(predictions, "Respiratory"), )