
Commit

Set predicted recall to zero for docs without predicted labels
Christopher Bartz committed Oct 4, 2022
1 parent ed521ea commit 22b631d
Showing 2 changed files with 93 additions and 9 deletions.
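
In short, the pipeline's predict step now treats a document whose predicted label list is empty as having recall 0 and only runs the recall regressor on the remaining documents. A minimal standalone sketch of the selection step (illustrative only; it mirrors the _get_pdata_idxs_with_zero_labels helper added below and is not part of the commit):

from typing import Collection, List

def zero_label_idxs(predicted_labels: List[List[str]]) -> Collection[int]:
    # A document qualifies when its predicted label list is empty (falsy).
    return [i for i, labels in enumerate(predicted_labels) if not labels]

# Same label layout as the new test fixture below.
print(zero_label_idxs([['c'], [], ['c'], [], ['c']]))  # [1, 3]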
70 changes: 61 additions & 9 deletions qualle/pipeline.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import List, Callable, Any
+from typing import List, Callable, Any, Collection

 from sklearn.model_selection import cross_val_predict

@@ -65,15 +65,67 @@ def train(self, data: TrainData):
         self._recall_predictor.fit(features_data, true_recall)

     def predict(self, data: PredictData) -> List[float]:
-        predicted_no_of_labels = self._label_calibrator.predict(data.docs)
-        label_calibration_data = LabelCalibrationData(
-            predicted_labels=data.predicted_labels,
-            predicted_no_of_labels=predicted_no_of_labels
-        )
-        features_data = self._features_data_mapper(
-            data, label_calibration_data
-        )
-        return self._recall_predictor.predict(features_data)
+        zero_idxs = self._get_pdata_idxs_with_zero_labels(data)
+        data_with_labels = self._get_pdata_with_labels(data, zero_idxs)
+        if data_with_labels.docs:
+            predicted_no_of_labels = self._label_calibrator.predict(
+                data_with_labels.docs
+            )
+            label_calibration_data = LabelCalibrationData(
+                predicted_labels=data_with_labels.predicted_labels,
+                predicted_no_of_labels=predicted_no_of_labels,
+            )
+            features_data = self._features_data_mapper(
+                data_with_labels, label_calibration_data
+            )
+            predicted_recall = self._recall_predictor.predict(
+                features_data
+            )
+            recall_scores = self._merge_zero_recall_with_predicted_recall(
+                predicted_recall=predicted_recall,
+                zero_labels_idx=zero_idxs,
+            )
+        else:
+            recall_scores = [0] * len(data.predicted_labels)
+        return recall_scores
+
+    @staticmethod
+    def _get_pdata_idxs_with_zero_labels(data: PredictData) -> Collection[int]:
+        return [
+            i for i in range(len(data.predicted_labels))
+            if not data.predicted_labels[i]
+        ]
+
+    @staticmethod
+    def _get_pdata_with_labels(
+        data: PredictData, zero_labels_idxs: Collection[int]
+    ) -> PredictData:
+        non_zero_idxs = [
+            i for i in range(len(data.predicted_labels))
+            if i not in zero_labels_idxs
+        ]
+        return PredictData(
+            docs=[data.docs[i] for i in non_zero_idxs],
+            predicted_labels=[data.predicted_labels[i] for i in non_zero_idxs],
+            scores=[data.scores[i] for i in non_zero_idxs],
+        )
+
+    @staticmethod
+    def _merge_zero_recall_with_predicted_recall(
+        predicted_recall: List[float],
+        zero_labels_idx: Collection[int],
+    ):
+        recall_scores = []
+        j = 0
+        for i in range(len(zero_labels_idx) + len(predicted_recall)):
+            if i in zero_labels_idx:
+                recall_scores.append(0)
+            else:
+                recall_scores.append(predicted_recall[j])
+                j += 1
+        return recall_scores

     @contextmanager
     def _debug(self, method_name):
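
For reference, the merge step above can be read on its own: a 0 is emitted at each position that had no predicted labels, and the predicted values fill the remaining positions in order. A minimal standalone sketch (illustrative only, mirroring _merge_zero_recall_with_predicted_recall; not part of the commit):

from typing import Collection, List

def merge_zero_recall(predicted_recall: List[float],
                      zero_labels_idx: Collection[int]) -> List[float]:
    # Walk the combined length; emit 0 at empty-label positions,
    # otherwise consume the next predicted value.
    recall_scores, j = [], 0
    for i in range(len(zero_labels_idx) + len(predicted_recall)):
        if i in zero_labels_idx:
            recall_scores.append(0)
        else:
            recall_scores.append(predicted_recall[j])
            j += 1
    return recall_scores

print(merge_zero_recall([1.0, 1.0, 1.0], [1, 3]))  # [1.0, 0, 1.0, 0, 1.0]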
32 changes: 32 additions & 0 deletions tests/test_pipeline.py
@@ -58,6 +58,20 @@ def train_data():
     )


+@pytest.fixture
+def train_data_with_some_empty_labels(train_data):
+    train_data.predict_data.predicted_labels = [['c'], [], ['c'], [], ['c']]
+
+    return train_data
+
+
+@pytest.fixture
+def train_data_with_all_empty_labels(train_data):
+    train_data.predict_data.predicted_labels = [[]] * 5
+
+    return train_data
+
+
 def test_train(qp, train_data, mocker):
     calibrator = qp._label_calibrator
     mocker.spy(calibrator, 'fit')
@@ -99,6 +113,24 @@ def test_predict(qp, train_data):
     assert np.array_equal(qp.predict(p_data), [1] * 5)


+def test_predict_with_some_empty_labels_returns_zero_recall(
+        qp, train_data_with_some_empty_labels):
+    p_data = train_data_with_some_empty_labels.predict_data
+
+    qp.train(train_data_with_some_empty_labels)
+
+    assert np.array_equal(qp.predict(p_data), [1, 0, 1, 0, 1])
+
+
+def test_predict_with_all_empty_labels_returns_only_zero_recall(
+        qp, train_data_with_all_empty_labels):
+    p_data = train_data_with_all_empty_labels.predict_data
+
+    qp.train(train_data_with_all_empty_labels)
+
+    assert np.array_equal(qp.predict(p_data), [0] * 5)
+
+
 def test_debug_prints_time_if_activated(qp, caplog):
     qp._should_debug = True
     caplog.set_level(logging.DEBUG)
