Skip to content

Commit

Permalink
fixed imports
Browse files Browse the repository at this point in the history
  • Loading branch information
technocreep committed Dec 20, 2023
1 parent 23ad9fb commit 12d6363
Showing 1 changed file with 16 additions and 36 deletions.
52 changes: 16 additions & 36 deletions fedot_ind/tools/explain/explain.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math

import lime
import lime.lime_tabular
# import lime
# import lime.lime_tabular
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
# import shap
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from tqdm import tqdm
Expand Down Expand Up @@ -162,36 +162,16 @@ def plot_importance(self, thr=90, name='dataset'):
plt.show()


class ShapExplainer:
def __init__(self, model, features, target, prediction):
self.model = model
self.features = features
self.target = target
self.prediction = prediction

def explain(self, n_samples: int = 5):
X_test = self.features

explainer = shap.KernelExplainer(self.model.predict, X_test, n_samples=n_samples)
shap_values = explainer.shap_values(X_test.iloc[:n_samples, :])
shap.summary_plot(shap_values, X_test.iloc[:n_samples, :], plot_type="bar")


class LimeExplainer:
def __init__(self, model, train_features, test_features, target, prediction):
self.model = model
self.train_features = train_features
self.test_features = test_features
self.target = target
self.prediction = prediction

def explain(self, n_samples):
explainer = lime.lime_tabular.LimeTabularExplainer(training_data=self.train_features.values,
feature_names=self.train_features.columns,
class_names=self.target,
discretize_continuous=True)
i = np.random.randint(0, self.test_features.shape[0])
exp = explainer.explain_instance(data_row=self.test_features.iloc[i, :].values,
predict_fn=self.model.predict_proba,
num_features=10)
exp.show_in_notebook(show_table=True, show_all=False)
# class ShapExplainer:
# def __init__(self, model, features, target, prediction):
# self.model = model
# self.features = features
# self.target = target
# self.prediction = prediction
#
# def explain(self, n_samples: int = 5):
# X_test = self.features
#
# explainer = shap.KernelExplainer(self.model.predict, X_test, n_samples=n_samples)
# shap_values = explainer.shap_values(X_test.iloc[:n_samples, :])
# shap.summary_plot(shap_values, X_test.iloc[:n_samples, :], plot_type="bar")

0 comments on commit 12d6363

Please sign in to comment.