Skip to content

Commit

Permalink
Vectorize ARProcess() (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Sep 12, 2024
1 parent bfaa0c0 commit 3d06ab9
Show file tree
Hide file tree
Showing 8 changed files with 446 additions and 147 deletions.
177 changes: 125 additions & 52 deletions pyrenew/process/ar.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# numpydoc ignore=GL08
"""
This file defines a RandomVariable subclass for
autoregressive (AR) processes
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike
from numpyro.contrib.control_flow import scan
from numpyro.infer.reparam import LocScaleReparam

from pyrenew.metaclass import RandomVariable
from pyrenew.process.iidrandomsequence import StandardNormalSequence


class ARProcess(RandomVariable):
Expand All @@ -16,32 +21,23 @@ class ARProcess(RandomVariable):
an AR(p) process.
"""

def __init__(self, noise_rv_name: str, *args, **kwargs) -> None:
"""
Default constructor.
Parameters
----------
noise_rv_name : str
A name for the internal RandomVariable
holding the process noise.
"""
super().__init__(*args, **kwargs)
self.noise_rv_ = StandardNormalSequence(element_rv_name=noise_rv_name)

def sample(
self,
noise_name: str,
n: int,
autoreg: ArrayLike,
init_vals: ArrayLike,
noise_sd: float | ArrayLike,
**kwargs,
) -> ArrayLike:
"""
Sample from the AR process
Parameters
----------
noise_name: str
A name for the sample site holding the
Normal(`0`, `noise_sd`) noise for the AR process.
Passed to :func:`numpyro.sample`.
n: int
Length of the sequence.
autoreg: ArrayLike
Expand All @@ -52,60 +48,137 @@ def sample(
init_vals : ArrayLike
Array of initial values. Must have the
same first dimension size as the order.
noise_sd : float | ArrayLike
Scalar giving the s.d. of the AR
noise_sd : ArrayLike
Standard deviation of the AR
process Normal noise, which by
definition has mean 0.
**kwargs : dict, optional
Additional keyword arguments passed to
self.noise_rv_.sample()
Returns
-------
ArrayLike
with first dimension of length `n`
and additional dimensions as inferred
from the shapes of `autoreg`,
`init_vals`, and `noise_sd`.
Notes
-----
The first dimension of the return value
with be of length `n` and represents time.
Trailing dimensions follow standard numpy
broadcasting rules and are determined from
the second through `n` th dimensions, if any,
of `autoreg` and `init_vals`, as well as the
all dimensions of `noise_sd` (i.e.
:code:`jax.numpy.shape(autoreg)[1:]`,
:code:`jax.numpy.shape(init_vals)[1:]`
and :code:`jax.numpy.shape(noise_sd)`
Those shapes must be
broadcastable together via
:func:`jax.lax.broadcast_shapes`. This can
be used to produce multiple AR processes of the
same order but with either shared or different initial
values, AR coefficient vectors, and/or
and noise standard deviation values.
"""
noise_sd_arr = jnp.atleast_1d(noise_sd)
if not noise_sd_arr.shape == (1,):
raise ValueError("noise_sd must be a scalar. " f"Got {noise_sd}")
autoreg = jnp.atleast_1d(autoreg)
init_vals = jnp.atleast_1d(init_vals)
noise_sd = jnp.array(noise_sd)
# noise_sd can be a scalar, but
# autoreg and init_vals must have a
# a first dimension (time),
# as the order of the process is
# inferred from that first dimension

if not autoreg.ndim == 1:
raise ValueError(
"Array of autoregressive coefficients "
"must be no more than 1 dimension",
f"Got {autoreg.ndim}",
order = autoreg.shape[0]
n_inits = init_vals.shape[0]

try:
noise_shape = jax.lax.broadcast_shapes(
init_vals.shape[1:],
autoreg.shape[1:],
noise_sd.shape,
)
if not init_vals.ndim == 1:
except Exception as e:
raise ValueError(
"Array of initial values must be " "no more than 1 dimension",
f"Got {init_vals.ndim}",
)
order = autoreg.size
if not init_vals.size == order:
"Could not determine a "
"valid shape for the AR process noise "
"from the shapes of the init_vals, "
"autoreg, and noise_sd arrays. "
"See ARProcess.sample() documentation "
"for details."
) from e

if not n_inits == order:
raise ValueError(
"Array of initial values must be "
"be the same size as the order of "
"the autoregressive process, "
"which is determined by the number "
"of autoregressive coefficients "
"provided. Got {init_vals.size} "
"initial values for a process of "
f"order {order}"
"Initial values array must have the same "
"first dimension length as the order p of "
"the AR process. The order is given by "
"the first dimension length of the array "
"of autoregressive coefficients. Got an initial "
f"value array with first dimension {n_inits} for "
f"a process of order {order}"
)

raw_noise = self.noise_rv_(n=n, **kwargs)
noise = noise_sd_arr * raw_noise
history_shape = (order,) + noise_shape

try:
inits_broadcast = jnp.broadcast_to(init_vals, history_shape)
except Exception as e:
raise ValueError(
"Could not broadcast init_vals "
f"(shape {init_vals.shape}) "
"to the expected shape of the process "
f"history (shape {history_shape}). "
"History shape is determined by the "
"shapes of the init_vals, autoreg, and "
"noise_sd arrays. See ARProcess "
"documentation for details"
) from e

inits_flipped = jnp.flip(inits_broadcast, axis=0)

def transition(recent_vals, _): # numpydoc ignore=GL08
with numpyro.handlers.reparam(
config={noise_name: LocScaleReparam(0)}
):
next_noise = numpyro.sample(
noise_name,
numpyro.distributions.Normal(
loc=jnp.zeros(noise_shape), scale=noise_sd
),
)

dot_prod = jnp.einsum("i...,i...->...", autoreg, recent_vals)
new_term = dot_prod + next_noise
new_recent_vals = jnp.concatenate(
[
new_term[jnp.newaxis, ...],
# concatenate as (1 time unit,) + noise_shape
# array
recent_vals,
],
axis=0,
)[:order]

def transition(recent_vals, next_noise): # numpydoc ignore=GL08
new_term = jnp.dot(autoreg, recent_vals) + next_noise
new_recent_vals = jnp.hstack(
[new_term, recent_vals[: (order - 1)]]
)
return new_recent_vals, new_term

last, ts = scan(transition, init_vals, noise)
return jnp.hstack([init_vals, ts])
if n > order:
_, ts = scan(
f=transition,
init=inits_flipped,
xs=None,
length=(n - order),
)

ts_with_inits = jnp.concatenate(
[inits_broadcast, ts],
axis=0,
)
else:
ts_with_inits = inits_broadcast
return ts_with_inits[:n]

@staticmethod
def validate(): # numpydoc ignore=RT01
Expand Down
9 changes: 8 additions & 1 deletion pyrenew/process/iidrandomsequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class StandardNormalSequence(IIDRandomSequence):
def __init__(
self,
element_rv_name: str,
element_shape: tuple = None,
**kwargs,
):
"""
Expand All @@ -124,13 +125,19 @@ def __init__(
DistributionalVariable encoding a
standard Normal (mean = 0, sd = 1)
distribution.
element_shape : tuple
Shape for each element in the sequence.
If None, elements are scalars. Default
None.
Returns
-------
None
"""
if element_shape is None:
element_shape = ()
super().__init__(
element_rv=DistributionalVariable(
name=element_rv_name, distribution=dist.Normal(0, 1)
),
).expand_by(element_shape)
)
17 changes: 9 additions & 8 deletions pyrenew/process/rtperiodicdiffar.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def __init__(
self.log_rt_rv = log_rt_rv
self.autoreg_rv = autoreg_rv
self.periodic_diff_sd_rv = periodic_diff_sd_rv
self.ar_process_suffix = ar_process_suffix

self.ar_diff = DifferencedProcess(
fundamental_process=ARProcess(
noise_rv_name=f"{name}{ar_process_suffix}"
),
fundamental_process=ARProcess(),
differencing_order=1,
)

Expand Down Expand Up @@ -138,22 +138,23 @@ def sample(
"""

# Initial sample
log_rt_rv = self.log_rt_rv.sample(**kwargs)
b = self.autoreg_rv.sample(**kwargs)
s_r = self.periodic_diff_sd_rv.sample(**kwargs)
log_rt_rv = self.log_rt_rv(**kwargs).squeeze()
b = self.autoreg_rv(**kwargs).squeeze()
s_r = self.periodic_diff_sd_rv(**kwargs).squeeze()

# How many periods to sample?
n_periods = (duration + self.period_size - 1) // self.period_size

# Running the process

log_rt = self.ar_diff(
noise_name=f"{self.name}{self.ar_process_suffix}",
n=n_periods,
init_vals=jnp.array([log_rt_rv[0]]),
init_vals=jnp.array(log_rt_rv[0]),
autoreg=b,
noise_sd=s_r,
fundamental_process_init_vals=jnp.array(
[log_rt_rv[1] - log_rt_rv[0]]
log_rt_rv[1] - log_rt_rv[0]
),
)

Expand Down
4 changes: 2 additions & 2 deletions pyrenew/randomvariable/distributionalvariable.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def sample(
def expand_by(self, sample_shape) -> Self:
"""
Expand the distribution by a given
shape_shape, if possible. Returns a
sample_shape, if possible. Returns a
new DynamicDistributionalVariable whose underlying
distribution will be expanded by the given shape
at sample() time.
Expand Down Expand Up @@ -326,5 +326,5 @@ def DistributionalVariable(
"(for instantiating a static DistributionalVariable) "
"or a callable that returns a "
"numpyro.distributions.Distribution (for "
"a dynamic DistributionalVariable"
"a dynamic DistributionalVariable)."
)
Loading

0 comments on commit 3d06ab9

Please sign in to comment.