Skip to content

Commit

Permalink
Merge branch 'Novartis:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
qiancao authored Jul 19, 2024
2 parents a78fab4 + fa1d306 commit af5f5bf
Show file tree
Hide file tree
Showing 14 changed files with 300 additions and 41 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,16 @@ The table below compares the functionalities of `TorchSurv` with those of
[deepsurv](https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-018-0482-1).
While several libraries offer survival modelling functionalities, no existing library provides the flexibility to use a custom PyTorch-based neural networks to define the survival model parameters.

The outputs of both the log-likelihood functions and the evaluation metrics functions have undergone thorough comparison with benchmarks generated using Python packages and R packages. The comparisons are summarised in the [Related packages summary](https://opensource.nibr.com/torchsurv/benchmarks.html).
The outputs of both the log-likelihood functions and the evaluation metrics functions have **undergone thorough comparison with benchmarks generated** using `Python` and `R` packages. The comparisons (at time of publication) are summarised in the [Related packages summary](https://opensource.nibr.com/torchsurv/benchmarks.html).

![Survival analysis libraries in Python](docs/source/table_python_benchmark.png)
![Survival analysis libraries in Python](docs/source/table_python_benchmark_legend.png)

Survival analysis libraries in R. For obtaining the evaluation metrics, packages `survival`, `riskRegression`, `SurvMetrics` and `pec` require the fitted model object as input (a specific object format) and `RisksetROC` imposes a smoothing method. Packages `timeROC`, `riskRegression` and pec force the user to choose a form for subject-specific
weights (e.g., inverse probability of censoring weighting (IPCW)). Packages `survcomp` and `SurvivalROC` do not implement the general AUC but the censoring-adjusted AUC estimator proposed by Heagerty et al. (2000).

![Survival analysis libraries in R](docs/source/table_r_benchmark.png)

## Contributing

We value contributions from the community to enhance and improve this project. If you'd like to contribute, please consider the following:
Expand Down Expand Up @@ -231,4 +236,4 @@ If you use this project in academic work or publications, we appreciate citing i
primaryClass={cs.LG},
doi={https://doi.org/10.48550/arXiv.2404.10761}
}
```
```
2 changes: 2 additions & 0 deletions dev/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ dependencies:
- nbsphinx=0.9.3
- ipython=8.20.0
- twine=4.0.2
- scipy=1.12.0
- sphinx-math-dollar=1.2.1
1 change: 0 additions & 1 deletion docs/AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ Contributors
* David Ohlssen <[email protected]> `(contributor)`
* Berkman Sahiner <[email protected]> `(contributor)`
* Nicholas Petrick <[email protected]> `(contributor)`

2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
"nbsphinx",
"sphinx.ext.viewcode",
"sphinxcontrib.bibtex",
"sphinx_math_dollar",
]


# templates_path = ['_templates']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Welcome to TorchSurv's documentation!
:maxdepth: 2
:caption: Tutorials:

notebooks/survival
notebooks/introduction
notebooks/momentum

Expand Down
264 changes: 264 additions & 0 deletions docs/notebooks/survival.md

Large diffs are not rendered by default.

Binary file added docs/source/table_r_benchmark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 0 additions & 6 deletions src/torchsurv/metrics/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,6 @@ def _find_torch_unique_indices(
def _validate_auc_inputs(
estimate, time, auc_type, new_time, weight, weight_new_time
):

# check new_time and weight are provided, weight_new_time should be provided
if all([new_time is not None, weight is not None, weight_new_time is None]):
raise ValueError(
Expand Down Expand Up @@ -1222,12 +1221,10 @@ def _update_auc_new_time(
weight: torch.tensor,
weight_new_time: torch.tensor,
) -> torch.tensor:

# update new time
if (
new_time is not None
): # if new_time are specified: ensure it has the correct format

# ensure that new_time are float
if isinstance(new_time, int):
new_time = torch.tensor([new_time]).float()
Expand All @@ -1237,7 +1234,6 @@ def _update_auc_new_time(
new_time = new_time.unsqueeze(0)

else: # else: find new_time

# if new_time are not specified, use unique event time
mask = event & (time < torch.max(time))
new_time, inverse_indices, counts = torch.unique(
Expand All @@ -1261,7 +1257,6 @@ def _update_auc_new_time(
def _update_auc_estimate(
estimate: torch.tensor, new_time: torch.tensor
) -> torch.tensor:

# squeeze estimate if shape = (n_samples, 1)
if estimate.ndim == 2 and estimate.shape[1] == 1:
estimate = estimate.squeeze(1)
Expand All @@ -1281,7 +1276,6 @@ def _update_auc_weight(
weight: torch.tensor,
weight_new_time: torch.tensor,
) -> torch.tensor:

# if weight was not specified, weight of 1
if weight is None:
weight = torch.ones_like(time)
Expand Down
27 changes: 13 additions & 14 deletions src/torchsurv/metrics/brier_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,13 @@ def __call__(
)

# update inputs as required
estimate, new_time, weight, weight_new_time = (
BrierScore._update_brier_score_new_time(
estimate, time, new_time, weight, weight_new_time
)
(
estimate,
new_time,
weight,
weight_new_time,
) = BrierScore._update_brier_score_new_time(
estimate, time, new_time, weight, weight_new_time
)
weight, weight_new_time = BrierScore._update_brier_score_weight(
time, new_time, weight, weight_new_time
Expand All @@ -190,14 +193,14 @@ def __call__(

# Calculating the residuals for each subject and time point
residuals = torch.zeros_like(estimate)
for i, t in enumerate(new_time):
est = estimate[:, i]
is_case = ((time <= t) & (event)).int()
is_control = (time > t).int()
for index, new_time_i in enumerate(new_time):
est = estimate[:, index]
is_case = ((time <= new_time_i) & (event)).int()
is_control = (time > new_time_i).int()

residuals[:, i] = (
residuals[:, index] = (
torch.square(est) * is_case * weight
+ torch.square(1.0 - est) * is_control * weight_new_time[i]
+ torch.square(1.0 - est) * is_control * weight_new_time[index]
)

# Calculating the brier scores at each time point
Expand Down Expand Up @@ -827,7 +830,6 @@ def _validate_brier_score_inputs(
weight: torch.tensor,
weight_new_time: torch.tensor,
) -> torch.tensor:

# check new_time and weight are provided, weight_new_time should be provided
if all([new_time is not None, weight is not None, weight_new_time is None]):
raise ValueError(
Expand Down Expand Up @@ -859,7 +861,6 @@ def _update_brier_score_new_time(
weight: torch.tensor,
weight_new_time: torch.tensor,
) -> torch.tensor:

# check format of new_time
if (
new_time is not None
Expand All @@ -871,7 +872,6 @@ def _update_brier_score_new_time(
new_time = new_time.unsqueeze(0)

else: # else: find new_time

# if new_time are not specified, use unique time
new_time, inverse_indices, counts = torch.unique(
time, sorted=True, return_inverse=True, return_counts=True
Expand All @@ -896,7 +896,6 @@ def _update_brier_score_weight(
weight: torch.tensor,
weight_new_time: torch.tensor,
) -> torch.tensor:

# if weight was not specified, weight of 1
if weight is None:
weight = torch.ones_like(time)
Expand Down
8 changes: 8 additions & 0 deletions src/torchsurv/metrics/cindex.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import sys
import warnings
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -639,6 +640,13 @@ def _compare_noether(self, other):
cindex1_se = self._concordance_index_se()
cindex2_se = other._concordance_index_se()

# Suppress the specific warning
warnings.filterwarnings(
"ignore",
message="Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.",
category=UserWarning,
)

# compute spearman correlation between risk prediction
corr = regression.SpearmanCorrCoef()(
self.estimate.reshape(-1), other.estimate.reshape(-1)
Expand Down
4 changes: 0 additions & 4 deletions tests/test_cox.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def test_boolean_y(self):
def test_log_likelihood_without_ties(self):
"""test cox partial log likelihood without ties on lung and gbsg datasets"""
for benchmark_cox_loglik in benchmark_cox_logliks:

if benchmark_cox_loglik["no_ties"][0] == True:

log_lik = -cox(
torch.tensor(
benchmark_cox_loglik["log_hazard"], dtype=torch.float32
Expand All @@ -82,9 +80,7 @@ def test_log_likelihood_without_ties(self):
def test_log_likelihood_with_ties(self):
"""test Efron and Breslow's approximation of cox partial log likelihood with ties on lung and gbsg data"""
for benchmark_cox_loglik in benchmark_cox_logliks:

if benchmark_cox_loglik["no_ties"][0] == False:

# efron approximation of partial log likelihood
log_lik_efron = -cox(
torch.tensor(
Expand Down
7 changes: 3 additions & 4 deletions tests/test_kaplan_meier.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def test_kaplan_meier_survival_distribution_real_data(self):
"""test Kaplan Meier survival distribution estimate on lung and gbsg datasets"""

for benchmark_kaplan_meier in benchmark_kaplan_meiers:

event = torch.tensor(benchmark_kaplan_meier["status"]).bool()
time = torch.tensor(benchmark_kaplan_meier["time"], dtype=torch.float32)
new_time = torch.tensor(
Expand Down Expand Up @@ -209,9 +208,9 @@ def test_kaplan_meier_prediction_error_raised(self):
for batch in batch_container.batches:
(train_time, train_event, test_time, *_) = batch

train_event[-1] = (
False # if last event is censoring, the last KM is > 0 and it cannot predict beyond this time
)
train_event[
-1
] = False # if last event is censoring, the last KM is > 0 and it cannot predict beyond this time
km = KaplanMeierEstimator()
km(train_event, train_time, censoring_dist=False)

Expand Down
1 change: 0 additions & 1 deletion tests/test_momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


class TestMometum(unittest.TestCase):

def test_momentum_weibull(self):
model = Momentum(
backbone=nn.Sequential(nn.Linear(8, 2)), # Weibull expect two ouputs
Expand Down
9 changes: 0 additions & 9 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ def get_input_array(self) -> Tuple[np.array, np.array, np.array, np.array]:
)

def _generate_input(self):

# random maximum time in observational period
tmax = torch.randint(5, 500, (1,)).item()

Expand All @@ -297,7 +296,6 @@ def _generate_input(self):
self._generate_new_time()

def _generate_data(self, tmax: int, n_train: int, n_test: int):

# time-to-event or censoring in train
train_time = torch.randint(1, tmax + 1, (n_train,)).float()

Expand Down Expand Up @@ -340,7 +338,6 @@ def _generate_data(self, tmax: int, n_train: int, n_test: int):
def _enforce_conditions_data(
self, time: torch.tensor, event: torch.tensor, dataset_type: str
) -> Tuple[torch.tensor, torch.tensor]:

# if test max time should be greater than train max time
if dataset_type == "test":
if self.test_max_time_gt_train_max_time:
Expand Down Expand Up @@ -395,15 +392,13 @@ def _enforce_conditions_data(
return time, event

def _generate_estimate(self):

# random risk score for observations in test
estimate = torch.randn(len(self.test_event))

# enforce conditions risk score
self.estimate = self._enforce_conditions_estimate(estimate)

def _enforce_conditions_estimate(self, estimate: torch.tensor) -> torch.tensor:

# if there should be ties in risk score associated to patients with event
if self.ties_score_events:
estimate[torch.where(self.test_event == 1.0)[0][0]] = estimate[
Expand All @@ -425,7 +420,6 @@ def _enforce_conditions_estimate(self, estimate: torch.tensor) -> torch.tensor:
return estimate

def _generate_new_time(self):

if torch.all(self.test_event == False):
# if all patients are censored in test, no evaluation time
new_time = torch.tensor([])
Expand All @@ -447,7 +441,6 @@ def _generate_new_time(self):
self.new_time = self._enforce_conditions_time(new_time)

def _enforce_conditions_time(self, new_time: torch.tensor) -> torch.tensor:

# if the test max time should be included in evaluation time
if self.test_max_time_in_new_time:
new_time = torch.cat(
Expand All @@ -457,7 +450,6 @@ def _enforce_conditions_time(self, new_time: torch.tensor) -> torch.tensor:
return new_time

def _evaluate_conditions(self):

# are there ties in event times
self.has_train_ties_time_event = self._has_ties(
self.train_time[self.train_event == 1]
Expand Down Expand Up @@ -614,7 +606,6 @@ def generate_batches(self, n_batch: int, flags_to_set: list):
n_batch = len(flags_to_set)

for i in range(n_batch):

if i >= len(flags_to_set):
# simulate data without flag
self.generate_one_batch()
Expand Down

0 comments on commit af5f5bf

Please sign in to comment.