Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updating branch #55

Merged
merged 12 commits into from
Sep 26, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
run: python -m twine check dist/*

- name: Upload artifact
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: Python-package
path: dist
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/codeqc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ jobs:
shell: bash -l {0}
run: |
conda activate torchsurv
./dev/codeqc.sh
./dev/codeqc.sh check

- name: Tests
shell: bash -l {0}
run: |
Expand Down
25 changes: 18 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

![CodeQC](https://github.com/Novartis/torchsurv/actions/workflows/codeqc.yml/badge.svg?branch=main)
![Docs](https://github.com/Novartis/torchsurv/actions/workflows/docs.yml/badge.svg?branch=main)
[![PyPI - Version](https://img.shields.io/pypi/v/torchsurv)](https://pypi.org/project/torchsurv/)
[![arXiv](https://img.shields.io/badge/arXiv-2404.10761-f9f107.svg)](https://arxiv.org/abs/2404.10761)
[![status](https://camo.githubusercontent.com/22fa65b2a659780cddfac609463c5fe719e3ea82a28eb7a61e24b7c4e40eb56d/68747470733a2f2f6a6f73732e7468656f6a2e6f72672f7061706572732f30326437343936646132623963633334663961366530346361626632323938642f7374617475732e737667)](https://joss.theoj.org/papers/02d7496da2b9cc34f9a6e04cabf2298d)
[![PyPI - Version](https://img.shields.io/pypi/v/torchsurv?)](https://pypi.org/project/torchsurv/)
[![Conda](https://img.shields.io/conda/v/conda-forge/torchsurv?label=conda)](https://anaconda.org/conda-forge/torchsurv)
[![arXiv](https://img.shields.io/badge/arXiv-2404.10761-f9f107.svg?)](https://arxiv.org/abs/2404.10761)
[![Documentation](https://img.shields.io/badge/GithubPage-Sphinx-blue)](https://opensource.nibr.com/torchsurv/)
[![Downloads](https://static.pepy.tech/badge/torchsurv)](https://pepy.tech/project/torchsurv)
[![PyPI Downloads](https://img.shields.io/pypi/dm/torchsurv.svg?label=PyPI%20downloads)](
https://pypi.org/project/torchsurv/)
[![Conda Downloads](https://img.shields.io/conda/dn/conda-forge/torchsurv.svg?label=Conda%20downloads)](
https://anaconda.org/conda-forge/torchsurv)

`TorchSurv` is a Python package that serves as a companion tool to perform deep survival modeling within the `PyTorch` environment. Unlike existing libraries that impose specific parametric forms on users, `TorchSurv` enables the use of custom `PyTorch`-based deep survival models. With its lightweight design, minimal input requirements, full `PyTorch` backend, and freedom from restrictive survival model parameterizations, `TorchSurv` facilitates efficient survival model implementation, particularly beneficial for high-dimensional input data scenarios.

Expand Down Expand Up @@ -43,15 +46,23 @@ cindex.compare(cindexB)

## Installation and dependencies

First, install the package:

First, install the package using either [PyPI]([https://pypi.org/](https://pypi.org/project/torchsurv/)) or [Conda]([https://anaconda.org/anaconda/conda](https://anaconda.org/conda-forge/torchsurv))

- Using conda (`recommended`)
```bash
conda install conda-forge::torchsurv
```
- Using PyPI
```bash
pip install torchsurv
```

or for local installation (from package root / clone of this git repository):
- Using for local installation (`latest version`)

```bash
git clone <repo>
cd <repo>
pip install -e .
```

Expand Down Expand Up @@ -237,4 +248,4 @@ If you use this project in academic work or publications, we appreciate citing i
primaryClass={cs.LG},
doi={https://doi.org/10.48550/arXiv.2404.10761}
}
```
```
2 changes: 1 addition & 1 deletion dev/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ channels:
- conda-forge
- pytorch
dependencies:
- build=0.7.0
- python-build=1.2.2
- pep517=0.13.0
- numpy=1.26.4
- pandas=2.2.0
Expand Down
10 changes: 10 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
Change log
=========

Version 0.1.3 (unreleased)
--------------------------

* Tutorial dataset error on momentum.ipynb #50
* Fix issue #48 - log_hazard returns torch.Inf
* Fix warning with Spearman correlation #41
* Added in-depth statistical background to link AUC to C-index #39
* Created Conda Forge version #47
* Updated CICD builds #53

Version 0.1.2
-------------

Expand Down
4 changes: 2 additions & 2 deletions src/torchsurv/loss/momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def forward(
self.memory_k.append(self.survtuple(*list(estimate)))
return loss

@torch.no_grad()
@torch.no_grad() # deactivates autograd
def infer(self, inputs: torch.Tensor) -> torch.Tensor:
"""Evaluate data with target network

Expand All @@ -183,7 +183,7 @@ def infer(self, inputs: torch.Tensor) -> torch.Tensor:
[ 0.9771, -0.8513]])

"""
self.target.eval() # Disable training tricks (augmentation, dropout, etc..)
self.target.eval() # notify all your layers that you are in eval mode
return self.target(inputs)

def _bank_loss(self) -> torch.Tensor:
Expand Down
9 changes: 7 additions & 2 deletions src/torchsurv/loss/weibull.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def log_hazard(
>>> for t in torch.tensor([100.0, 150.0]): log_hazard(log_params, time=t) # Subject-specific log hazard at multiple new times
tensor([ 1.1280, -0.0372, -3.9767, 1.0757])
tensor([ 1.2330, -0.1062, -4.1680, 1.1999])
>>> log_params *= 1e2 # Increase scale
>>> log_hazard(log_params, time, all_times = False) # Check for Torch.Inf values
tensor([-1.0000e+10, -2.3197e+01, -6.8385e+01, -1.0000e+10])
"""

log_scale, log_shape = _check_log_shape(log_params).unbind(1)
Expand All @@ -247,11 +250,13 @@ def log_hazard(
f"Dimension mismatch: 'time' ({len(time)}) does not match the length of 'log_params' ({len(log_params)})."
)

return (
return torch.clamp(
log_shape
- log_scale
+ torch.expm1(log_shape)
* (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale)
* (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale),
min=-TORCH_CLAMP_VALUE,
max=TORCH_CLAMP_VALUE,
)


Expand Down
6 changes: 3 additions & 3 deletions tests/test_kaplan_meier.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def test_kaplan_meier_prediction_error_raised(self):
for batch in batch_container.batches:
(train_time, train_event, test_time, *_) = batch

train_event[
-1
] = False # if last event is censoring, the last KM is > 0 and it cannot predict beyond this time
train_event[-1] = (
False # if last event is censoring, the last KM is > 0 and it cannot predict beyond this time
)
km = KaplanMeierEstimator()
km(train_event, train_time, censoring_dist=False)

Expand Down
Loading