Skip to content

Commit

Permalink
Merge pull request #49 from Novartis/48-log_hazard-returns-torchinf
Browse files Browse the repository at this point in the history
resolve torch.Inf issue
  • Loading branch information
melodiemonod authored Sep 23, 2024
2 parents 41f5911 + 015e45d commit 723192f
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/torchsurv/loss/weibull.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def log_hazard(
>>> for t in torch.tensor([100.0, 150.0]): log_hazard(log_params, time=t) # Subject-specific log hazard at multiple new times
tensor([ 1.1280, -0.0372, -3.9767, 1.0757])
tensor([ 1.2330, -0.1062, -4.1680, 1.1999])
>>> log_params *= 1e2 # Increase scale
>>> log_hazard(log_params, time, all_times = False) # Check for Torch.Inf values
tensor([-1.0000e+10, -2.3197e+01, -6.8385e+01, -1.0000e+10])
"""

log_scale, log_shape = _check_log_shape(log_params).unbind(1)
Expand All @@ -247,11 +250,13 @@ def log_hazard(
f"Dimension mismatch: 'time' ({len(time)}) does not match the length of 'log_params' ({len(log_params)})."
)

return (
return torch.clamp(
log_shape
- log_scale
+ torch.expm1(log_shape)
* (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale)
* (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale),
min=-TORCH_CLAMP_VALUE,
max=TORCH_CLAMP_VALUE,
)


Expand Down

0 comments on commit 723192f

Please sign in to comment.