An error occurred when using the Explainer of the last version #127

Open
RamSnoussi opened this issue Jul 13, 2024 · 7 comments
Comments

@RamSnoussi

Hi @mirand863,
What is the problem here, and how can I fix this error?

from hiclass import LocalClassifierPerParentNode, Explainer
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import shap

X_train = np.array([
    [40.7,  1. ,  1. ,  2. ,  5. ,  2. ,  1. ,  5. , 34.3],
    [39.2,  0. ,  2. ,  4. ,  1. ,  3. ,  1. ,  2. , 34.1],
    [40.6,  0. ,  3. ,  1. ,  4. ,  5. ,  0. ,  6. , 27.7],
    [36.5,  0. ,  3. ,  1. ,  2. ,  2. ,  0. ,  2. , 39.9],
])

Y_train = np.array([
    ['Gastrointestinal', 'Norovirus', ''],
    ['Respiratory', 'Covid', ''],
    ['Allergy', 'External', 'Bee Allergy'],
    ['Respiratory', 'Cold', ''],
])

X_test = np.array([[35.5,  0. ,  1. ,  1. ,  3. ,  3. ,  0. ,  2. , 37.5]])


classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier())
classifier.fit(X_train, Y_train)
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test)
print(explanations)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 26
     24 classifier.fit(X_train, Y_train)
     25 explainer = Explainer(classifier, data=X_train, mode="tree")
---> 26 explanations = explainer.explain(X_test)
     27 print(explanations)

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/hiclass/Explainer.py:124, in Explainer.explain(self, X)
    117 check_array(X)
    119 if (
    120     isinstance(self.hierarchical_model, LocalClassifierPerParentNode)
    121     or isinstance(self.hierarchical_model, LocalClassifierPerLevel)
    122     or isinstance(self.hierarchical_model, LocalClassifierPerNode)
    123 ):
--> 124     return self._explain_with_xr(X)
    125 else:
    126     raise ValueError(f"Invalid model: {self.hierarchical_model}.")

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/hiclass/Explainer.py:142, in Explainer._explain_with_xr(self, X)
    128 def _explain_with_xr(self, X):
    129     """
    130     Generate SHAP values for each node using the SHAP package.
    131 
   (...)
    140         An xarray Dataset consisting of SHAP values for each sample.
    141     """
--> 142     explanations = Parallel(n_jobs=self.n_jobs, backend="threading")(
    143         delayed(self._calculate_shap_values)(sample.reshape(1, -1)) for sample in X
    144     )
    146     dataset = xr.concat(explanations, dim="sample")
    147     return dataset

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/joblib/parallel.py:1918, in Parallel.__call__(self, iterable)
   1916     output = self._get_sequential_output(iterable)
   1917     next(output)
-> 1918     return output if self.return_generator else list(output)
   1920 # Let's create an ID that uniquely identifies the current call. If the
   1921 # call is interrupted early and that the same instance is immediately
   1922 # re-used, this id will be used to prevent workers that were
   1923 # concurrently finalizing a task from the previous call to run the
   1924 # callback.
   1925 with self._lock:

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/joblib/parallel.py:1847, in Parallel._get_sequential_output(self, iterable)
   1845 self.n_dispatched_batches += 1
   1846 self.n_dispatched_tasks += 1
-> 1847 res = func(*args, **kwargs)
   1848 self.n_completed_tasks += 1
   1849 self.print_progress()

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/hiclass/Explainer.py:275, in Explainer._calculate_shap_values(self, X)
    273     traversed_nodes = self._get_traversed_nodes_lcpl(X)[0]
    274 elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
--> 275     traversed_nodes = self._get_traversed_nodes_lcppn(X)[0]
    276 elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
    277     traversed_nodes = self._get_traversed_nodes_lcpn(X)[0]

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/hiclass/Explainer.py:170, in Explainer._get_traversed_nodes_lcppn(self, samples)
    164 traversals = np.empty(
    165     (samples.shape[0], self.hierarchical_model.max_levels_),
    166     dtype=self.hierarchical_model.dtype_,
    167 )
    169 # Initialize first element as root node
--> 170 traversals[:, 0] = self.hierarchical_model.root_
    172 # For subsequent nodes, calculate mask and find predictions
    173 for level in range(1, traversals.shape[1]):

ValueError: invalid literal for int() with base 10: 'hiclass::root'
@RamSnoussi
Author

Hi,
The cause seems to be the label encoding: it changes the type of y_ from string to int64 (line 223 of HierarchicalClassifier.py), so traversals (line 164 of Explainer.py) is allocated as int64. However, self.hierarchical_model.root_ (line 170 of Explainer.py) is still a string, so the assignment fails. Explainer.py should therefore be updated to match.
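
A minimal sketch of the mismatch, using only NumPy. The array shape and the variable names are illustrative stand-ins, not code taken from hiclass; only the root label 'hiclass::root' comes from the traceback above.

```python
import numpy as np

# Stand-in for the traversal array in Explainer.py: after the labels are
# encoded, it is allocated with an integer dtype.
traversals = np.empty((1, 3), dtype=np.int64)

# Stand-in for hierarchical_model.root_, which is still a string.
root = "hiclass::root"

try:
    # NumPy tries to convert the string to an integer here and fails.
    traversals[:, 0] = root
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```

This should fail in the same way as line 170 of Explainer.py in the traceback, which suggests the problem is the dtype of the array rather than anything specific to the test data.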

@RamSnoussi
Author

Hi @mirand863,
Do you have any suggestions about this issue?

@mirand863
Collaborator

mirand863 commented Jul 17, 2024

Hi @RamSnoussi,

Can you please explain a bit more about how you are using the explainer? Are you using it with encoded labels, as in the other issue we were discussing previously, so that the root should also be an integer? Is that correct?

@RamSnoussi
Author

Hi @mirand863,
I'm using hiclass version v4.10.0 (https://github.com/scikit-learn-contrib/hiclass). Please look at the example above and the error it generates (traversals[:, 0] = self.hierarchical_model.root_ at line 170 of Explainer.py). The error occurs because traversals[:, 0] has an integer dtype while self.hierarchical_model.root_ is a string. You modified hiclass by adding the encoder (in HierarchicalClassifier.py), but Explainer.py should be updated too.
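
One possible direction, sketched with plain NumPy: choose a dtype for the traversal array that can hold the root label as well as the encoded labels. The helper make_traversals and its parameters are hypothetical and are not part of hiclass; an actual fix in Explainer.py might instead encode the root or decode the labels.

```python
import numpy as np


def make_traversals(n_samples, max_levels, root, label_dtype):
    """Hypothetical helper: allocate a traversal array whose first column is the root."""
    dtype = np.dtype(label_dtype)
    # A string root cannot be stored in a numeric array, so fall back to
    # object dtype, which can hold strings and integers side by side.
    if isinstance(root, str) and dtype.kind not in ("U", "S", "O"):
        dtype = np.dtype(object)
    traversals = np.empty((n_samples, max_levels), dtype=dtype)
    traversals[:, 0] = root
    return traversals


# Integer labels with a string root: the array falls back to object dtype.
mixed = make_traversals(1, 3, "hiclass::root", np.int64)
print(mixed.dtype, mixed[0, 0])
```

Falling back to object dtype keeps the assignment from failing, but any later comparison between traversed nodes and encoded labels would still need to treat the root consistently.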

@RamSnoussi
Author

The attached file (HierarchicalClassifier.py) is the one in my execution environment, where the encoder is used at line 221. However, the encoder has been removed from the version on GitHub (https://github.com/scikit-learn-contrib/hiclass/blob/main/hiclass/HierarchicalClassifier.py). Why is that? Which released version can I use?
HierarchicalClassifier.txt

@mirand863
Collaborator

mirand863 commented Jul 26, 2024

The attached file (HierarchicalClassifier.py) is the one in my execution environment, where the encoder is used at line 221. However, the encoder has been removed from the version on GitHub (https://github.com/scikit-learn-contrib/hiclass/blob/main/hiclass/HierarchicalClassifier.py). Why is that? Which released version can I use? HierarchicalClassifier.txt

I see now. The encoder has not been released yet; it is only available in a branch called cuml (main...cuml).

I was just testing it out and never actually released it. I also did not need the explainer in my use case, so I did not update that file. I might be able to do it in the next few days if it is important for your use case. I can add tests and try to make it run without bugs, with a proper release. :)

@RamSnoussi
Author

Hi @mirand863,
Have you updated the Explainer file to work with the added encoder? Thanks.
