Skip to content

Commit

Permalink
fix combine (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
kegl authored and jorisvandenbossche committed Oct 10, 2018
1 parent 09f007e commit 0156adc
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import datetime

import warnings
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
Expand Down Expand Up @@ -98,6 +99,26 @@ def y_pred_label_index(self):
"""Multi-class y_pred is the index of the predicted label."""
return np.argmax(self.y_pred[:, 1:], axis=1)

@classmethod
def combine(cls, predictions_list, index_list=None):
if index_list is None: # we combine the full list
index_list = range(len(predictions_list))
y_comb_list = np.array(
[predictions_list[i].y_pred for i in index_list])
# clipping probas into [0, 1], also taking care of the case of all
# zeros
y_comb_list[:, :, 1:] = np.clip(
y_comb_list[:, :, 1:], 10 ** -15, 1 - 10 ** -15)
# normalizing probabilities
y_comb_list[:, :, 1:] = y_comb_list[:, :, 1:] / np.sum(
y_comb_list[:, :, 1:], axis=2, keepdims=True)
# I expect to see RuntimeWarnings in this block
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
y_comb = np.nanmean(y_comb_list, axis=0)
combined_predictions = cls(y_pred=y_comb)
return combined_predictions


# -----------------------------------------------------------------------------
# Score types
Expand Down Expand Up @@ -355,7 +376,7 @@ def turn_prediction_to_event_list(y, thres=0.5):


score_types = [
# log-loss
# mixed log-loss/f1 score
Mixed(),
# log-loss
PointwiseLogLoss(),
Expand All @@ -364,10 +385,9 @@ def turn_prediction_to_event_list(y, thres=0.5):
PointwiseRecall(),
# event-based precision and recall
EventwisePrecision(),
EventwiseRecall()
EventwiseRecall(),
]


# -----------------------------------------------------------------------------
# Cross-validation scheme
# -----------------------------------------------------------------------------
Expand Down

0 comments on commit 0156adc

Please sign in to comment.