diff --git a/responsibleai_text/rai_text_insights/rai_text_insights.py b/responsibleai_text/rai_text_insights/rai_text_insights.py index 3a54a6c070..c4c04c2109 100644 --- a/responsibleai_text/rai_text_insights/rai_text_insights.py +++ b/responsibleai_text/rai_text_insights/rai_text_insights.py @@ -15,7 +15,6 @@ from erroranalysis._internal.cohort_filter import FilterDataWithCohortFilters from ml_wrappers import wrap_model from raiutils.data_processing import convert_to_list, serialize_json_safe -from raiutils.models import ModelTask as RAIModelTask from raiutils.models import SKLearn, is_classifier from responsibleai._interfaces import Dataset, RAIInsightsData from responsibleai._internal.constants import (ManagerNames, Metadata, @@ -432,15 +431,13 @@ def _is_classification_task(self): def _get_dataset(self): dashboard_dataset = Dataset() tasktype = self.task_type - is_classification_task = self._is_classification_task - if is_classification_task: - tasktype = RAIModelTask.CLASSIFICATION if isinstance(tasktype, Enum): tasktype = tasktype.value dashboard_dataset.task_type = tasktype dashboard_dataset.categorical_features = [] dashboard_dataset.class_names = convert_to_list( self._classes) + is_classification_task = self._is_classification_task dataset = self._get_test_without_target(is_classification_task) predicted_y = None