From 22b631d8fbf231fed1fa17b9d288a0ce15b49e15 Mon Sep 17 00:00:00 2001
From: Christopher Bartz
Date: Tue, 4 Oct 2022 15:14:24 +0200
Subject: [PATCH] Set predicted recall to zero for docs without predicted
 labels

---
 qualle/pipeline.py     | 70 ++++++++++++++++++++++++++++++++++++------
 tests/test_pipeline.py | 32 +++++++++++++++++++
 2 files changed, 93 insertions(+), 9 deletions(-)

diff --git a/qualle/pipeline.py b/qualle/pipeline.py
index 54967e6..085551e 100644
--- a/qualle/pipeline.py
+++ b/qualle/pipeline.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import List, Callable, Any
+from typing import List, Callable, Any, Collection
 
 from sklearn.model_selection import cross_val_predict
 
@@ -65,15 +65,67 @@ def train(self, data: TrainData):
         self._recall_predictor.fit(features_data, true_recall)
 
     def predict(self, data: PredictData) -> List[float]:
-        predicted_no_of_labels = self._label_calibrator.predict(data.docs)
-        label_calibration_data = LabelCalibrationData(
-            predicted_labels=data.predicted_labels,
-            predicted_no_of_labels=predicted_no_of_labels
-        )
-        features_data = self._features_data_mapper(
-            data, label_calibration_data
+        zero_idxs = self._get_pdata_idxs_with_zero_labels(data)
+        data_with_labels = self._get_pdata_with_labels(data, zero_idxs)
+        if data_with_labels.docs:
+            predicted_no_of_labels = self._label_calibrator.predict(
+                data_with_labels.docs
+            )
+            label_calibration_data = LabelCalibrationData(
+                predicted_labels=data_with_labels.predicted_labels,
+                predicted_no_of_labels=predicted_no_of_labels,
+            )
+            features_data = self._features_data_mapper(
+                data_with_labels, label_calibration_data
+            )
+            predicted_recall = self._recall_predictor.predict(
+                features_data
+            )
+            recall_scores = self._merge_zero_recall_with_predicted_recall(
+                predicted_recall=predicted_recall,
+                zero_labels_idx=zero_idxs,
+            )
+        else:
+            recall_scores = [0] * len(data.predicted_labels)
+        return recall_scores
+
+    @staticmethod
+    def _get_pdata_idxs_with_zero_labels(data: PredictData) -> Collection[int]:
+        return [
+            i for i in range(len(data.predicted_labels))
+            if not data.predicted_labels[i]
+        ]
+
+    @staticmethod
+    def _get_pdata_with_labels(
+        data: PredictData, zero_labels_idxs: Collection[int]
+    ) -> PredictData:
+        non_zero_idxs = [
+            i for i in range(len(data.predicted_labels))
+            if i not in zero_labels_idxs
+        ]
+        return PredictData(
+            docs=[data.docs[i] for i in non_zero_idxs],
+            predicted_labels=[data.predicted_labels[i] for i in non_zero_idxs],
+            scores=[data.scores[i] for i in non_zero_idxs],
         )
-        return self._recall_predictor.predict(features_data)
+
+    @staticmethod
+    def _merge_zero_recall_with_predicted_recall(
+        predicted_recall: List[float],
+        zero_labels_idx: Collection[int],
+    ) -> List[float]:
+        recall_scores = []
+        j = 0
+        for i in range(
+                len(zero_labels_idx)
+                + len(predicted_recall)):
+            if i in zero_labels_idx:
+                recall_scores.append(0)
+            else:
+                recall_scores.append(predicted_recall[j])
+                j += 1
+        return recall_scores
 
     @contextmanager
     def _debug(self, method_name):
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 2fe3550..0aa9403 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -58,6 +58,20 @@ def train_data():
     )
 
 
+@pytest.fixture
+def train_data_with_some_empty_labels(train_data):
+    train_data.predict_data.predicted_labels = [['c'], [], ['c'], [], ['c']]
+
+    return train_data
+
+
+@pytest.fixture
+def train_data_with_all_empty_labels(train_data):
+    train_data.predict_data.predicted_labels = [[]] * 5
+
+    return train_data
+
+
 def test_train(qp, train_data, mocker):
     calibrator = qp._label_calibrator
     mocker.spy(calibrator, 'fit')
@@ -99,6 +113,24 @@ def test_predict(qp, train_data):
     assert np.array_equal(qp.predict(p_data), [1] * 5)
 
 
+def test_predict_with_some_empty_labels_returns_zero_recall(
+        qp, train_data_with_some_empty_labels):
+    p_data = train_data_with_some_empty_labels.predict_data
+
+    qp.train(train_data_with_some_empty_labels)
+
+    assert np.array_equal(qp.predict(p_data), [1, 0, 1, 0, 1])
+
+
+def test_predict_with_all_empty_labels_returns_only_zero_recall(
+        qp, train_data_with_all_empty_labels):
+    p_data = train_data_with_all_empty_labels.predict_data
+
+    qp.train(train_data_with_all_empty_labels)
+
+    assert np.array_equal(qp.predict(p_data), [0] * 5)
+
+
 def test_debug_prints_time_if_activated(qp, caplog):
     qp._should_debug = True
     caplog.set_level(logging.DEBUG)
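Reviewer note: the merge step is the heart of this change, so here is a
minimal standalone sketch of the interleaving (plain Python, no qualle
imports; the free function and the example values below are illustrative
only, not part of the patch). Indices of docs whose predicted_labels are
empty are pinned to recall 0, while the recall predictor's outputs fill
the remaining positions in their original order:

    from typing import Collection, List


    def merge_zero_recall_with_predicted_recall(
        predicted_recall: List[float],
        zero_labels_idx: Collection[int],
    ) -> List[float]:
        # Rebuild the full score list: positions of docs without predicted
        # labels get recall 0, every other slot consumes the next value
        # from predicted_recall in order.
        recall_scores = []
        j = 0
        for i in range(len(zero_labels_idx) + len(predicted_recall)):
            if i in zero_labels_idx:
                recall_scores.append(0)
            else:
                recall_scores.append(predicted_recall[j])
                j += 1
        return recall_scores


    # Docs 1 and 3 had empty predicted_labels, so their recall is pinned
    # to 0; the model's predictions fill positions 0, 2 and 4 in order.
    assert merge_zero_recall_with_predicted_recall(
        predicted_recall=[0.8, 0.6, 0.9],
        zero_labels_idx={1, 3},
    ) == [0.8, 0, 0.6, 0, 0.9]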