Skip to content

Commit

Permalink
Fix model wrapper, remove task type from error analysis (microsoft#1912)
Browse files Browse the repository at this point in the history
  • Loading branch information
tongyu-microsoft authored Jan 17, 2023
1 parent 5f6dd92 commit 9d24ea2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 19 deletions.
71 changes: 54 additions & 17 deletions responsibleai/responsibleai/managers/error_analysis_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from erroranalysis._internal.error_report import \
json_converter as report_json_converter
from responsibleai._config.base_config import BaseConfig
from responsibleai._interfaces import ErrorAnalysisData, TaskType
from responsibleai._interfaces import ErrorAnalysisData
from responsibleai._internal.constants import ErrorAnalysisManagerKeys as Keys
from responsibleai._internal.constants import ListProperties, ManagerNames
from responsibleai._tools.shared.state_directory_management import \
Expand Down Expand Up @@ -82,13 +82,26 @@ def as_error_config(json_dict):
return json_dict


class MetadataRemovalModelWrapper():
"""Defines MetadataRemovalModelWrapper, wrapping the model
to ignore dropped feature metadata if any."""
def get_wrapped_model(model, dropped_features):
predict_proba_flag = hasattr(model, 'predict_proba')
if predict_proba_flag:
wrapper_model = MetadataRemovalClassificationModelWrapper(
model,
dropped_features)
else:
wrapper_model = MetadataRemovalRegressionModelWrapper(
model,
dropped_features)
return wrapper_model


class MetadataRemovalClassificationModelWrapper():
"""Defines MetadataRemovalClassificationModelWrapper, wrapping the
classification model to ignore dropped feature metadata if any."""

def __init__(self, model: any,
dropped_features: Optional[List[str]] = None):
"""If needed, wraps the model to ignore the dropped features.
"""If needed, wraps the classification model to ignore the dropped features.
:param model: The model or function to evaluate on the examples.
:type model: function or model with a predict or predict_proba function
Expand All @@ -111,6 +124,30 @@ def _apply_func(self, func, dataset):
return func(dataset.drop(columns=self.dropped_features, axis=1))


class MetadataRemovalRegressionModelWrapper():
"""Defines MetadataRemovalRegressionModelWrapper, wrapping the
regression model to ignore dropped feature metadata if any."""

def __init__(self, model: any,
dropped_features: Optional[List[str]] = None):
"""If needed, wraps the model to ignore the dropped features.
:param model: The model or function to evaluate on the examples.
:type model: function or model with a predict function
:param dropped_features: List of features that were dropped by the
the user during training of their model.
:type dropped_features: Optional[List[str]]
"""
self.model = model
self.dropped_features = dropped_features

def predict(self, dataset: pd.DataFrame):
if self.dropped_features is None or len(self.dropped_features) == 0:
return self.model.predict(dataset)
return self.model.predict(dataset.drop(
columns=self.dropped_features, axis=1))


class ErrorAnalysisConfig(BaseConfig):
"""Defines the ErrorAnalysisConfig, specifying the parameters to run."""

Expand Down Expand Up @@ -188,8 +225,7 @@ class ErrorAnalysisManager(BaseManager):
def __init__(self, model: Any, dataset: pd.DataFrame, target_column: str,
classes: Optional[List] = None,
categorical_features: Optional[List[str]] = None,
dropped_features: Optional[List[str]] = None,
task_type: Optional[TaskType] = None):
dropped_features: Optional[List[str]] = None):
"""Creates an ErrorAnalysisManager object.
:param model: The model to analyze errors on.
Expand All @@ -210,8 +246,6 @@ def __init__(self, model: Any, dataset: pd.DataFrame, target_column: str,
training. This includes metadata that is useful
for evaluating the model.
:type dropped_features: Optional[List[str]]
:param task_type: The task type of the model.
:type task_type: TaskType
"""
self._true_y = dataset[target_column]
self._dataset = dataset.drop(columns=[target_column])
Expand All @@ -220,13 +254,15 @@ def __init__(self, model: Any, dataset: pd.DataFrame, target_column: str,
self._categorical_features = categorical_features
self._ea_config_list = []
self._ea_report_list = []
self._analyzer = ModelAnalyzer(MetadataRemovalModelWrapper(
model, dropped_features),
wrapper_model = get_wrapped_model(
model,
dropped_features)
self._analyzer = ModelAnalyzer(
wrapper_model,
self._dataset,
self._true_y,
self._feature_names,
self._categorical_features,
model_task=task_type,
classes=self._classes)

def add(self, max_depth: int = 3, num_leaves: int = 31,
Expand Down Expand Up @@ -440,7 +476,6 @@ def _load(path, rai_insights):
inst.__dict__['_ea_report_list'] = ea_report_list
inst.__dict__['_ea_config_list'] = ea_config_list

task_type = rai_insights.task_type
categorical_features = rai_insights.categorical_features
inst.__dict__['_categorical_features'] = categorical_features
target_column = rai_insights.target_column
Expand All @@ -454,11 +489,13 @@ def _load(path, rai_insights):
if rai_insights._feature_metadata is not None:
dropped_features = rai_insights._feature_metadata.dropped_features
inst.__dict__['_dropped_features'] = dropped_features
inst.__dict__['_analyzer'] = ModelAnalyzer(MetadataRemovalModelWrapper(
rai_insights.model, dropped_features),
wrapper_model = get_wrapped_model(
rai_insights.model,
dropped_features)
inst.__dict__['_analyzer'] = ModelAnalyzer(
wrapper_model,
dataset,
true_y,
feature_names,
categorical_features,
model_task=task_type)
categorical_features)
return inst
3 changes: 1 addition & 2 deletions responsibleai/responsibleai/rai_insights/rai_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ def _initialize_managers(self):
self.model, self.test, self.target_column,
self._classes,
self.categorical_features,
dropped_features,
task_type=self.task_type)
dropped_features)

self._explainer_manager = ExplainerManager(
self.model, self.get_train_data(), self.get_test_data(),
Expand Down

0 comments on commit 9d24ea2

Please sign in to comment.