Skip to content

Commit

Permalink
First Pass Documentation (docstrings) For Top Level Pyrenew Files (#89)
Browse files Browse the repository at this point in the history
* numpydoc changes (testing)

* numpydoc changes (testing)

* poetry lock and precommit for top-level docs

* docs for convolve; poetry lock; ignore TODOs

* docs for convolve; poetry lock; ignore TODOs

* distutil

* math docs

* minor mcmcutils edit

* update poetry lock file

* top level file modifications

* pre-commit edit
  • Loading branch information
AFg6K7h4fhy2 authored Apr 22, 2024
1 parent 437016d commit afb93a9
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 45 deletions.
70 changes: 65 additions & 5 deletions model/src/pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,88 @@
jax.lax.scan. Factories generate functions
that can be passed to scan with an
appropriate array to scan.
Notes
-----
TODO: Look into adding blocks for Functions and Examples in this
docstring.
"""
from typing import Callable, Tuple

import jax.numpy as jnp
from jax.typing import ArrayLike


def new_convolve_scanner(discrete_dist_flipped: ArrayLike) -> Callable:
    """
    Factory function to create a scanner function for convolving a
    discrete distribution over a time series data subset.

    Parameters
    ----------
    discrete_dist_flipped : ArrayLike
        A 1D jax array representing the discrete distribution flipped
        for convolution.

    Returns
    -------
    Callable
        A scanner function that can be used with jax.lax.scan for
        convolution. It takes a history subset and a multiplier, computes
        the dot product of the flipped distribution with the history
        subset, scales it by the multiplier, and returns a tuple of
        (updated history subset, new value). The updated history subset
        is the input with its first entry dropped and the new value
        appended at the end.

    Notes
    -----
    TODO: Add Example.
    """

    def _new_scanner(
        history_subset: ArrayLike, multiplier: float
    ) -> Tuple[ArrayLike, float]:
        new_val = multiplier * jnp.dot(discrete_dist_flipped, history_subset)
        # Slide the window: drop the first entry and append the new value
        # so the carried history keeps the same length.
        latest = jnp.hstack([history_subset[1:], new_val])
        return latest, new_val

    return _new_scanner


def new_double_scanner(
    dists: Tuple[ArrayLike, ArrayLike], transforms: Tuple[Callable, Callable]
) -> Callable:
    """
    Factory function to create a scanner function that applies two
    sequential transformations and convolutions using two discrete
    distributions.

    Parameters
    ----------
    dists : Tuple[ArrayLike, ArrayLike]
        A tuple of two 1D jax arrays, each representing a discrete
        distribution for the two stages of convolution.
    transforms : Tuple[Callable, Callable]
        A tuple of two functions, each transforming the output of the
        dot product at each convolution stage.

    Returns
    -------
    Callable
        A scanner function that applies two sequential convolutions and
        transformations. It takes a history subset and a tuple of two
        multipliers and returns a tuple whose first element is the
        updated history subset (first entry dropped, new value appended)
        and whose second element is the pair
        ``(new_val, m_net1)``: the second-stage result followed by the
        transformed first-stage result.

    Notes
    -----
    TODO: Add Example
    """
    d1, d2 = dists
    t1, t2 = transforms

    def _new_scanner(
        history_subset: ArrayLike, multipliers: Tuple[float, float]
    ) -> Tuple[ArrayLike, Tuple[float, float]]:
        m1, m2 = multipliers
        # First stage: scaled convolution with d1, then the first transform.
        m_net1 = t1(m1 * jnp.dot(d1, history_subset))
        # Second stage: the first-stage result acts as an extra multiplier
        # on the convolution with d2 before the second transform.
        new_val = t2(m2 * m_net1 * jnp.dot(d2, history_subset))
        latest = jnp.hstack([history_subset[1:], new_val])
        return latest, (new_val, m_net1)

    return _new_scanner
38 changes: 35 additions & 3 deletions model/src/pyrenew/distutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,37 @@
such as discrete time-to-event distributions
"""
import jax.numpy as jnp
from jax.typing import ArrayLike


def validate_discrete_dist_vector(
discrete_dist: jnp.ndarray, tol: float = 1e-20
) -> bool:
discrete_dist: ArrayLike, tol: float = 1e-20
) -> ArrayLike:
"""
Validate that a vector represents a discrete
probability distribution to within a specified
tolerance, raising a ValueError if not.
Parameters
----------
discrete_dist : ArrayLike
A jax array containing non-negative values that
represent a discrete probability distribution. The values
must sum to 1 within the specified tolerance.
tol : float, optional
The tolerance within which the sum of the distribution must
be 1. Defaults to 1e-20.
Returns
-------
ArrayLike
The normalized distribution array if the input is valid.
Raises
------
ValueError
If any value in discrete_dist is negative or if the sum of the
distribution does not equal 1 within the specified tolerance.
"""
discrete_dist = discrete_dist.flatten()
if not jnp.all(discrete_dist >= 0):
Expand All @@ -39,10 +61,20 @@ def validate_discrete_dist_vector(
return discrete_dist / dist_norm


def reverse_discrete_dist_vector(dist: ArrayLike) -> ArrayLike:
    """
    Reverse a discrete distribution
    vector (useful for discrete
    time-to-event distributions).

    Parameters
    ----------
    dist : ArrayLike
        A discrete distribution vector (likely a discrete
        time-to-event distribution).

    Returns
    -------
    ArrayLike
        The input vector with its entries in reversed order
        (the result of jnp.flip).
    """
    return jnp.flip(dist)
53 changes: 35 additions & 18 deletions model/src/pyrenew/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
a given renewal process.
"""


import jax.numpy as jnp
from jax.typing import ArrayLike
from pyrenew.distutil import validate_discrete_dist_vector


def get_leslie_matrix(R, generation_interval_pmf):
def get_leslie_matrix(R: float, generation_interval_pmf: ArrayLike) -> float:
"""
Create the Leslie matrix
corresponding to a basic
Expand All @@ -28,7 +30,7 @@ def get_leslie_matrix(R, generation_interval_pmf):
mass vector of the renewal process
Returns
--------
-------
The Leslie matrix for the
renewal process, as a jax array.
"""
Expand All @@ -44,14 +46,17 @@ def get_leslie_matrix(R, generation_interval_pmf):
return jnp.vstack([R * generation_interval_pmf, aging_matrix])


def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf):
def get_asymptotic_growth_rate_and_age_dist(
R: float, generation_interval_pmf: ArrayLike
) -> tuple[float, ArrayLike]:
"""
Get the asymptotic per-timestep growth
rate of the renewal process (the dominant
eigenvalue of its Leslie matrix) and the
associated stable age distribution
(a normalized eigenvector associated to
that eigenvalue).
Parameters
----------
R : float
Expand All @@ -61,11 +66,17 @@ def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf):
mass vector of the renewal process
Returns
--------
A tuple consisting of the asymptotic growth rate of
the process, as jax float, and the stable age distribution
of the process, as a jax array probability vector of the
same shape as the generation interval probability vector.
-------
tuple[float, ArrayLike]
A tuple consisting of the asymptotic growth rate of
the process, as jax float, and the stable age distribution
of the process, as a jax array probability vector of the
same shape as the generation interval probability vector.
Raises
------
ValueError
If an age distribution vector with non-zero imaginary part is produced.
"""
L = get_leslie_matrix(R, generation_interval_pmf)
eigenvals, eigenvecs = jnp.linalg.eig(L)
Expand All @@ -92,7 +103,9 @@ def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf):
return d_val_real, d_vec_norm


def get_stable_age_distribution(R, generation_interval_pmf):
def get_stable_age_distribution(
R: float, generation_interval_pmf: ArrayLike
) -> ArrayLike:
"""
Get the stable age distribution for a
renewal process with a given value of
Expand All @@ -114,18 +127,21 @@ def get_stable_age_distribution(R, generation_interval_pmf):
mass vector of the renewal process
Returns
--------
The stable age distribution for the
process, as a jax array probability vector of
the same shape as the generation interval
probability vector.
-------
ArrayLike
The stable age distribution for the
process, as a jax array probability vector of
the same shape as the generation interval
probability vector.
"""
return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[
1
]


def get_asymptotic_growth_rate(R, generation_interval_pmf):
def get_asymptotic_growth_rate(
R: float, generation_interval_pmf: ArrayLike
) -> float:
"""
Get the asymptotic per timestep growth rate
for a renewal process with a given value of
Expand All @@ -145,9 +161,10 @@ def get_asymptotic_growth_rate(R, generation_interval_pmf):
mass vector of the renewal process
Returns
--------
The asymptotic growth rate of the renewal process,
as a jax float.
-------
float
The asymptotic growth rate of the renewal process,
as a jax float.
"""
return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[
0
Expand Down
5 changes: 3 additions & 2 deletions model/src/pyrenew/mcmcutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def spread_draws(
posteriors: dict,
variables_names: list,
variables_names: list[str] | list[tuple],
) -> pl.DataFrame:
"""Get nicely shaped draws from the posterior
Expand All @@ -29,7 +29,8 @@ def spread_draws(
Returns
-------
polars.DataFrame
pl.DataFrame
A dataframe of draw-indexed posterior samples.
"""

for i_var, v in enumerate(variables_names):
Expand Down
49 changes: 42 additions & 7 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,26 @@ def _assert_sample_and_rtype(
) -> None:
"""Return type-checking for RandomVariable's sample function
Objects passed as `RandomVariable` should (a) have a `sample()` method that
Objects passed as `RandomVariable` should (a) have a sample() method that
(b) returns either a tuple or a named tuple.
Parameters
----------
rp : RandomVariable
Random variable to check.
skip_if_none: bool
When `True` it returns if `rp` is None.
skip_if_none: bool, optional
When `True` it returns if `rp` is None. Defaults to True
Returns
-------
None
Raises
------
Exception
If rp is not a RandomVariable, does not have a sample function, or
does not return a tuple. Also raised if rettype is not initialized
properly.
"""

# Addressing the None case
Expand Down Expand Up @@ -101,7 +108,7 @@ def sample(
----------
**kwargs : dict, optional
Additional keyword arguments passed through to internal `sample()`
calls, if any
calls, should there be any.
Notes
-----
Expand All @@ -117,6 +124,9 @@ def sample(
@staticmethod
@abstractmethod
def validate(**kwargs) -> None:
"""
Validation of kwargs to be implemented in subclasses.
"""
pass


Expand Down Expand Up @@ -149,7 +159,7 @@ def sample(
----------
**kwargs : dict, optional
Additional keyword arguments passed through to internal `sample()`
calls, if any
calls, should there be any.
Notes
-----
Expand Down Expand Up @@ -244,9 +254,34 @@ def print_summary(
prob: float = 0.9,
exclude_deterministic: bool = True,
) -> None:
"""A wrapper of MCMC.print_summary"""
"""
A wrapper of MCMC.print_summary
Parameters
----------
prob : float, optional
The probability mass of samples within the credible interval
reported by print_summary. Defaults to 0.9.
exclude_deterministic : bool, optional
Whether to exclude deterministic variables from the summary.
Defaults to True.
Returns
-------
None
"""
return self.mcmc.print_summary(prob, exclude_deterministic)

def spread_draws(self, variables_names: list) -> pl.DataFrame:
"""A wrapper of mcmcutils.spread_draws"""
"""A wrapper of mcmcutils.spread_draws
Parameters
----------
variables_names : list
A list of variable names to create a table of samples.
Returns
-------
pl.DataFrame
"""

return spread_draws(self.mcmc.get_samples(), variables_names)
Loading

0 comments on commit afb93a9

Please sign in to comment.