Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Accuracy in time #75

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
54 changes: 54 additions & 0 deletions hazardous/metrics/_accuracy_in_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np

from hazardous.utils import check_y_survival


def accuracy_in_time(y_test, y_pred, times, quantiles=None, taus=None):
Vincent-Maladiere marked this conversation as resolved.
Show resolved Hide resolved
event_true, _ = check_y_survival(y_test)

if y_pred.ndim != 3:
raise ValueError(
"'y_pred' must be a 3D array with shape (n_samples, n_events, n_times), got"
f" shape {y_pred.shape}."
)
if y_pred.shape[0] != event_true.shape[0]:
raise ValueError(
"'y_true' and 'y_pred' must have the same number of samples, "
f"got {event_true.shape[0]} and {y_pred.shape[0]} respectively."
)
times = np.atleast_1d(times)
if y_pred.shape[1] != times.shape[0]:
raise ValueError(
f"'times' length ({times.shape[0]}) "
f"must be equal to y_pred.shape[1] ({y_pred.shape[1]})."
)

if quantiles is not None:
if taus is not None:
raise ValueError("'quantiles' and 'taus' can't be set at the same time.")

quantiles = np.atleast_1d(quantiles)
if any(quantiles < 0) or any(quantiles > 1):
raise ValueError(f"quantiles must be in [0, 1], got {quantiles}.")
taus = np.quantile(times, quantiles)

elif quantiles is None and taus is None:
n_quantiles = min(times.shape[0], 8)
quantiles = np.linspace(1 / n_quantiles, 1, n_quantiles)
taus = np.quantile(times, quantiles)

acc_in_time = []

for tau in taus:
mask_past_censored = (y_test["event"] == 0) & (y_test["duration"] < tau)
Vincent-Maladiere marked this conversation as resolved.
Show resolved Hide resolved

tau_idx = np.searchsorted(times, tau)
y_pred_at_t = y_pred[:, :, tau_idx]
Vincent-Maladiere marked this conversation as resolved.
Show resolved Hide resolved
y_pred_class = y_pred_at_t[~mask_past_censored, :].argmax(axis=1)

y_test_class = y_test["event"] * (y_test["duration"] < tau)
Vincent-Maladiere marked this conversation as resolved.
Show resolved Hide resolved
y_test_class = y_test_class.loc[~mask_past_censored]
juAlberge marked this conversation as resolved.
Show resolved Hide resolved

acc_in_time.append((y_test_class.values == y_pred_class).mean())

return acc_in_time, taus
1 change: 1 addition & 0 deletions hazardous/metrics/_brier_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def brier_score_incidence(self, y_true, y_pred, times):
"'y_true' and 'y_pred' must have the same number of samples, "
f"got {event_true.shape[0]} and {y_pred.shape[0]} respectively."
)
times = np.atleast_1d(times)
if y_pred.shape[1] != times.shape[0]:
raise ValueError(
f"'times' length ({times.shape[0]}) "
Expand Down
Loading