Auto choose most appropriate explainable model #355

Open · wants to merge 5 commits into base: main
140 changes: 129 additions & 11 deletions python/interpret_community/mimic/mimic_explainer.py
@@ -21,7 +21,8 @@
from ..common.blackbox_explainer import BlackBoxExplainer

from .model_distill import _model_distill
from .models import LGBMExplainableModel
from .models import LGBMExplainableModel, LinearExplainableModel, SGDExplainableModel, \
DecisionTreeExplainableModel
from ..explanation.explanation import _create_local_explanation, _create_global_explanation, \
_aggregate_global_from_local_explanation, _aggregate_streamed_local_explanations, \
_create_raw_feats_global_explanation, _create_raw_feats_local_explanation, \
@@ -133,14 +134,19 @@ class MimicExplainer(BlackBoxExplainer):
:param reset_index: Uses the pandas DataFrame index column as part of the features when training
the surrogate model.
:type reset_index: str
:param auto_select_explainable_model: Set this to True to use the MimicExplainer with an
automatically selected explainable model. Four explainable models (LGBMExplainableModel,
LinearExplainableModel, SGDExplainableModel and DecisionTreeExplainableModel) are trained
and scored on how well they replicate the teacher model, and the best-scoring model is
then used to derive explanations.
:type auto_select_explainable_model: bool

Collaborator:

I wonder if this should be a separate explainer or function - the mimic explainer takes a specific surrogate model, not a list. This also seems like something that complicates the mimic explainer logic. Maybe we can discuss more.

Thinking of other libraries, there is usually a distinction between hyperparameter tuning and training (e.g., in both v1 studio and designer there is a Train Model module and a Tune Hyperparameters or Cross Validate module; in Spark ML the hyperparameter tuner is a separate estimator; similarly, in scikit-learn, grid search CV is a separate function). I feel that for users who want this we should have a separate function/class instead of complicating the current mimic explainer.
"""

@init_tabular_decorator
def __init__(self, model, initialization_examples, explainable_model, explainable_model_args=None,
is_function=False, augment_data=True, max_num_of_augmentations=10, explain_subset=None,
features=None, classes=None, transformations=None, allow_all_transformations=False,
shap_values_output=ShapValuesOutput.DEFAULT, categorical_features=None,
model_task=ModelTask.Unknown, reset_index=ResetIndex.Ignore, **kwargs):
model_task=ModelTask.Unknown, reset_index=ResetIndex.Ignore,
auto_select_explainable_model=False, **kwargs):
"""Initialize the MimicExplainer.

:param model: The black box model or function (if is_function is True) to be explained. Also known
@@ -233,6 +239,11 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
the index when calling predict on the original model. Only use reset_teacher if the index is already
featurized as part of the data.
:type reset_index: str
:param auto_select_explainable_model: Set this to True to use the MimicExplainer with an
automatically selected explainable model. Four explainable models (LGBMExplainableModel,
LinearExplainableModel, SGDExplainableModel and DecisionTreeExplainableModel) are trained
and scored on how well they replicate the teacher model, and the best-scoring model is
then used to derive explanations.
:type auto_select_explainable_model: bool
"""
if transformations is not None and explain_subset is not None:
raise ValueError("explain_subset not supported with transformations")
@@ -250,8 +261,7 @@
wrapped_model, eval_ml_domain = _wrap_model(model, initialization_examples, model_task, is_function)
super(MimicExplainer, self).__init__(wrapped_model, is_function=is_function,
model_task=eval_ml_domain, **kwargs)
if explainable_model_args is None:
explainable_model_args = {}

if categorical_features is None:
categorical_features = []
self._logger.debug('Initializing MimicExplainer')
@@ -288,7 +298,6 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
# Index the categorical string columns for training data
self._column_indexer = initialization_examples.string_index(columns=categorical_features)
self._one_hot_encoder = None
explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features
else:
# One-hot-encode categoricals for models that don't support categoricals natively
self._column_indexer = initialization_examples.string_index(columns=categorical_features)
@@ -304,14 +313,86 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
if isinstance(training_data, DenseData):
training_data = training_data.data

self._original_eval_examples = None
Collaborator:

This is quite a bit of logic to put inside the mimic explainer; I'm really wondering how we could simplify this, as the mimic explainer is already quite complicated.

self._allow_all_transformations = allow_all_transformations

if auto_select_explainable_model:
# Train all available surrogate models to find the respective replication scores
explainable_model_list = [LGBMExplainableModel, LinearExplainableModel,
SGDExplainableModel, DecisionTreeExplainableModel]
self._best_replication_score = None
self._all_replication_scores = {}
for some_explainable_model in explainable_model_list:
try:
# Set params for explainable model
some_args = self._supplement_explainable_model_args(
explainable_model=some_explainable_model,
explainable_model_args={},
categorical_features=categorical_features,
shap_values_output=shap_values_output)
# Train the explainable model
surrogate_model = _model_distill(self.function, some_explainable_model, training_data,
original_training_data, some_args)
# Compute the replication score between the teacher model and surrogate model
surrogate_replication_score = self._get_surrogate_model_replication_measure(
training_data=training_data,
surrogate_model=surrogate_model)
# Store the replication score
self._all_replication_scores[surrogate_model.method] = surrogate_replication_score

# Keep track of the best score and the best trained surrogate model
if self._best_replication_score is None or \
surrogate_replication_score > self._best_replication_score:
self.surrogate_model = surrogate_model
self._best_replication_score = surrogate_replication_score
except Exception:
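# Skip candidates that fail to train or score; the remaining candidates are still evaluated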
pass

if not auto_select_explainable_model or \
(hasattr(self, "_best_replication_score") and self._best_replication_score is None):
# If auto-selection is disabled, or training/scoring of every candidate explainable
# model failed, fall back on the user-specified explainable model and train it.
explainable_model_args = self._supplement_explainable_model_args(
explainable_model=explainable_model,
explainable_model_args=explainable_model_args,
categorical_features=categorical_features,
shap_values_output=shap_values_output)

self.surrogate_model = _model_distill(
self.function, explainable_model, training_data,
original_training_data, explainable_model_args)

try:
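# Default to None so the finally block records a score entry even if scoring fails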
surrogate_replication_score = None
# Compute the replication score between the teacher model and surrogate model
surrogate_replication_score = self._get_surrogate_model_replication_measure(
training_data=training_data,
surrogate_model=self.surrogate_model)
except Exception:
pass
finally:
# Store the replication score
self._all_replication_scores = {}
self._all_replication_scores[self.surrogate_model.method] = surrogate_replication_score
self._best_replication_score = surrogate_replication_score

self._method = self.surrogate_model.method

def _supplement_explainable_model_args(self, explainable_model, explainable_model_args,
categorical_features, shap_values_output):
if explainable_model_args is None:
explainable_model_args = {}

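# Only LightGBM among the candidate surrogates natively supports categorical
# features, so only it receives the categorical_feature argument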
if explainable_model.explainable_model_type == ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE and \
self._supports_categoricals(explainable_model):
explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features

explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag

if self._supports_shap_values_output(explainable_model):
explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output
self.surrogate_model = _model_distill(self.function, explainable_model, training_data,
original_training_data, explainable_model_args)
self._method = self.surrogate_model._method
self._original_eval_examples = None
self._allow_all_transformations = allow_all_transformations

return explainable_model_args

def _supports_categoricals(self, explainable_model):
return issubclass(explainable_model, LGBMExplainableModel)
@@ -630,6 +711,43 @@ def _load(model, properties):
mimic.__dict__[MimicSerializationConstants.ALLOW_ALL_TRANSFORMATIONS] = False
return mimic

def _get_surrogate_model_replication_measure(self, training_data, surrogate_model):
"""Return the metric which tells how well the surrogate model replicates the teacher model.

:param training_data: The data for getting the replication metric.
:type training_data: numpy.array or pandas.DataFrame or iml.datatypes.DenseData or
scipy.sparse.csr_matrix
:param surrogate_model: Trained surrogate model.
:type surrogate_model: Any
:return: Metric that tells how well the surrogate model replicates the behavior of the teacher model.
:rtype: float
"""
try:
from sklearn.metrics import accuracy_score
from sklearn.metrics import r2_score
except ImportError:
raise Exception(
"Cannot compute replication metrics due to missing sklearn metrics package")

surrogate_model_predictions = surrogate_model.predict(training_data)
teacher_model_predictions = self.model.predict(training_data)

if self.classes is not None:
if len(self.classes) > 2:
replication_measure = accuracy_score(teacher_model_predictions, surrogate_model_predictions)
else:
raise Exception("Replication measure is not supported for binary classification")
else:
if training_data.shape[0] == 1:
raise Exception("Replication measure for regression surrogate not supported "
"because of single instance in training data")
replication_measure = r2_score(teacher_model_predictions, surrogate_model_predictions)
return replication_measure
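
For intuition, a standalone sketch of what this measure computes, using plain scikit-learn metrics rather than library API: for multiclass classification the score is the surrogate's agreement with the teacher's predictions, and for regression it is the R² of the surrogate's predictions with the teacher's predictions treated as ground truth.

```python
import numpy as np
from sklearn.metrics import accuracy_score, r2_score

# Multiclass: fraction of instances where the surrogate agrees with the teacher.
teacher_preds = np.array([0, 1, 2, 2, 1])
surrogate_preds = np.array([0, 1, 2, 1, 1])
print(accuracy_score(teacher_preds, surrogate_preds))  # 0.8

# Regression: R^2 of surrogate predictions against teacher predictions.
teacher_reg = np.array([1.0, 2.0, 3.0, 4.0])
surrogate_reg = np.array([1.1, 1.9, 3.2, 3.8])
print(r2_score(teacher_reg, surrogate_reg))  # 0.98, close to 1 => good replication
```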

def __getstate__(self):
"""Influence how MimicExplainer is pickled.

5 changes: 5 additions & 0 deletions python/interpret_community/mimic/models/explainable_model.py
@@ -82,6 +82,11 @@ def explainable_model_type(self):
"""Retrieve the model type."""
pass

@property
def method(self):
"""Return the name of the explainable model."""
return self._method

def __getstate__(self):
"""Influence how SGDExplainableModel is pickled.

64 changes: 58 additions & 6 deletions test/test_mimic_explainer.py
@@ -451,8 +451,13 @@ def test_explain_model_string_classes(self, mimic_explainer):
transformations=feat_pipe)
global_explanation = explainer.explain_global(X.iloc[:1000])
assert global_explanation.method == LINEAR_METHOD
assert explainer._all_replication_scores is not None
assert 'linear' in explainer._all_replication_scores
assert explainer._all_replication_scores['linear'] is None
assert explainer._best_replication_score is None

def test_linear_explainable_model_regression(self, mimic_explainer):
@pytest.mark.parametrize('auto_select_explainable_model', [True, False])
def test_linear_explainable_model_regression(self, mimic_explainer, auto_select_explainable_model):
num_features = 3
x_train = np.array([['a', 'E', 'x'], ['c', 'D', 'y']])
y_train = np.array([1, 2])
@@ -464,19 +469,42 @@ def test_linear_explainable_model_regression(self, mimic_explainer):
explainable_model = LinearExplainableModel
explainer = mimic_explainer(model.named_steps['regressor'], x_train, explainable_model,
transformations=transformations, augment_data=False,
auto_select_explainable_model=auto_select_explainable_model,
explainable_model_args={'sparse_data': True}, features=['f1', 'f2', 'f3'])
global_explanation = explainer.explain_global(x_train)
assert global_explanation.method == LINEAR_METHOD

def test_linear_explainable_model_classification(self, mimic_explainer):
assert explainer._all_replication_scores is not None
assert explainer._best_replication_score is not None
if not auto_select_explainable_model:
assert 'linear' in explainer._all_replication_scores
assert explainer._all_replication_scores['linear'] is not None
else:
assert 'linear' in explainer._all_replication_scores
assert explainer._all_replication_scores['linear'] is not None
assert 'sgd' in explainer._all_replication_scores
assert explainer._all_replication_scores['sgd'] is not None
assert 'lightgbm' in explainer._all_replication_scores
assert explainer._all_replication_scores['lightgbm'] is not None
assert 'tree' in explainer._all_replication_scores
assert explainer._all_replication_scores['tree'] is not None

@pytest.mark.parametrize('if_multiclass', [True, False])
@pytest.mark.parametrize('auto_select_explainable_model', [True, False])
def test_linear_explainable_model_classification(self, mimic_explainer, if_multiclass,
auto_select_explainable_model):
n_samples = 100
n_cat_features = 15

cat_feature_names = [f'cat_feature_{i}' for i in range(n_cat_features)]
cat_features = np.random.choice(['a', 'b', 'c', 'd'], (n_samples, n_cat_features))

data_x = pd.DataFrame(cat_features, columns=cat_feature_names)
data_y = np.random.choice(['0', '1'], n_samples)
if if_multiclass:
data_y = np.random.choice([0, 1, 2, 3], n_samples)
classes = [0, 1, 2, 3]
else:
data_y = np.random.choice([0, 1], n_samples)
classes = [0, 1]

# prepare feature encoders
cat_feature_encoders = [OneHotEncoder().fit(cat_features[:, i].reshape(-1, 1)) for i in range(n_cat_features)]
@@ -498,11 +526,35 @@ def test_linear_explainable_model_classification(self, mimic_explainer):
explainable_model_args={'sparse_data': True},
augment_data=False,
features=cat_feature_names,
classes=['0', '1'],
classes=classes,
auto_select_explainable_model=auto_select_explainable_model,
transformations=cat_transformations,
model_task=ModelTask.Classification)
global_explanation = explainer.explain_global(evaluation_examples=data_x)
assert global_explanation.method == LINEAR_METHOD

if if_multiclass:
assert explainer._all_replication_scores is not None
assert explainer._best_replication_score is not None
if not auto_select_explainable_model:
assert global_explanation.method == LINEAR_METHOD
assert 'linear' in explainer._all_replication_scores
assert explainer._all_replication_scores['linear'] is not None
else:
assert global_explanation.method == LIGHTGBM_METHOD
assert 'linear' in explainer._all_replication_scores
assert explainer._all_replication_scores['linear'] is not None
assert 'sgd' in explainer._all_replication_scores
assert explainer._all_replication_scores['sgd'] is not None
assert 'lightgbm' in explainer._all_replication_scores
assert explainer._all_replication_scores['lightgbm'] is not None
assert 'tree' in explainer._all_replication_scores
assert explainer._all_replication_scores['tree'] is not None
else:
assert global_explanation.method == LINEAR_METHOD
assert explainer._all_replication_scores is not None
assert 'linear' in explainer._all_replication_scores
assert explainer._all_replication_scores['linear'] is None
assert explainer._best_replication_score is None

def test_dense_wide_data(self, mimic_explainer):
# use 6000 rows instead for real performance testing