Skip to content

Commit

Permalink
Refactoring the periodic broadcaster and added warn + error + test (#349
Browse files Browse the repository at this point in the history
)
  • Loading branch information
gvegayon authored Aug 7, 2024
1 parent a131fe2 commit bf09a37
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 151 deletions.
2 changes: 1 addition & 1 deletion docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ for i in range(0, 30, 7):
plt.show()
```

The implementation of the `RtWeeklyDiffProcess` (which is an instance of `RtPeriodicDiffProcess`), uses `PeriodicBroadcaster` to repeating values: `PeriodicBroadcaster(..., period_size=7, broadcast_type="repeat")`. Setting the `broadcast_type` to `"repeat"` repeats each vector element for the specified period size. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.
The implementation of `RtWeeklyDiffProcess` (which is an instance of `RtPeriodicDiffProcess`) uses `repeat_until_n` to repeat values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.

## Repeated sequences (tiling)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorials/time.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The `t_unit, t_start` pair can encode different types of time series data. For e

## How it relates to periodicity

The `PeriodicBroadcaster()` class provides a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect()` and `RtPeriodicDiffProcess()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.
The `tile_until_n()` and `repeat_until_n()` functions provide a way of tiling and repeating data while accounting for the starting time, but they do not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect()` and `RtPeriodicDiffProcess()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.

## Unimplemented features

Expand Down
244 changes: 115 additions & 129 deletions model/src/pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,138 +112,124 @@ def __repr__(self):
return f"PeriodicProcessSample(value={self.value})"


class PeriodicBroadcaster:
r"""
Broadcast arrays periodically using either repeat or tile,
considering period size and starting point.
def tile_until_n(
data: ArrayLike,
n_timepoints: int,
offset: int = 0,
) -> ArrayLike:
"""
Tile the data until it reaches `n_timepoints`.
def __init__(
self,
offset: int,
period_size: int,
broadcast_type: str,
) -> None:
"""
Default constructor for PeriodicBroadcaster class.
Parameters
----------
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
period_size : int
Size of the period.
broadcast_type : str
Type of broadcasting to use, either "repeat" or "tile".
Notes
-----
See the sample method for more information on the broadcasting types.
Returns
-------
None
"""

self.validate(
offset=offset,
period_size=period_size,
broadcast_type=broadcast_type,
)
Parameters
----------
data : ArrayLike
Data to broadcast.
n_timepoints : int
Duration of the sequence.
offset : int
Relative point at which data starts, must be a non-negative integer.
Default is zero, i.e., no offset.
Notes
-----
Using the `offset` parameter, the function will start the broadcast
from the `offset`-th element of the data. If the data is shorter than
`n_timepoints`, the function will tile the data until it
reaches `n_timepoints`.
self.period_size = period_size
self.offset = offset
self.broadcast_type = broadcast_type

return None

@staticmethod
def validate(offset: int, period_size: int, broadcast_type: str) -> None:
"""
Validate the input parameters.
Parameters
----------
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
period_size : int
Size of the period.
broadcast_type : str
Type of broadcasting to use, either "repeat" or "tile".
Returns
-------
None
"""

# Period size should be a positive integer
assert isinstance(
period_size, int
), f"period_size should be an integer. It is {type(period_size)}."

assert (
period_size > 0
), f"period_size should be a positive integer. It is {period_size}."

# Data starts should be a positive integer
assert isinstance(
offset, int
), f"offset should be an integer. It is {type(offset)}."

assert (
0 <= offset
), f"offset should be a positive integer. It is {offset}."

assert offset <= period_size - 1, (
"offset should be less than or equal to period_size - 1."
f"It is {offset}. It should be less than or equal "
f"to {period_size - 1}."
)
Returns
-------
ArrayLike
Tiled data.
"""

# Data starts should be a positive integer
assert isinstance(
offset, int
), f"offset should be an integer. It is {type(offset)}."

assert 0 <= offset, f"offset should be a positive integer. It is {offset}."

return jnp.tile(data, (n_timepoints // data.size) + 1)[
offset : (offset + n_timepoints)
]


def repeat_until_n(
data: ArrayLike,
period_size: int,
n_timepoints: int,
offset: int = 0,
):
"""
Repeat each entry in `data` a given number of times (`period_size`)
until an array of length `n_timepoints` has been produced.
Notes
-----
Using the `offset` parameter, the function will offset the data after
the repeat operation. So, if the offset is 2, the repeat operation
will repeat the data until `n_timepoints + 2` and then offset the
data by 2, returning an array of size `n_timepoints`. This is a way to start
part-way into a period. For example, the following code will repeat each array
element four times until 10 timepoints and then offset the data by 2:
.. code-block:: python
data = jnp.array([1, 2, 3])
repeat_until_n(data, 4, 10, 2)
# Array([1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=int32)
# Broadcast type should be either "repeat" or "tile"
assert broadcast_type in ["repeat", "tile"], (
"broadcast_type should be either 'repeat' or 'tile'. "
f"It is {broadcast_type}."
Parameters
----------
data : ArrayLike
Data to broadcast.
period_size : int
Size of the period for the repeat broadcast.
n_timepoints : int
Duration of the sequence.
offset : int, optional
Relative point at which data starts, must be between 0 and
period_size - 1. By default 0, i.e., no offset.
Returns
-------
ArrayLike
Repeated data.
"""

# Data starts should be a positive integer
assert isinstance(
offset, int
), f"offset should be an integer. It is {type(offset)}."

assert 0 <= offset, f"offset should be a positive integer. It is {offset}."

# Period size should be a positive integer
assert isinstance(
period_size, int
), f"period_size should be an integer. It is {type(period_size)}."

assert (
period_size > 0
), f"period_size should be a positive integer. It is {period_size}."

assert offset <= period_size - 1, (
"offset should be less than or equal to period_size - 1."
f"It is {offset}. It should be less than or equal "
f"to {period_size - 1}."
)

if (data.size * period_size) < (n_timepoints + offset):
raise ValueError(
"The data is too short to broadcast to the given number "
f"of timepoints + offset ({n_timepoints + offset}). The "
"repeated data would have a size of data.size * "
f"period_size = {data.size} * {period_size} = "
f"{data.size * period_size}."
)

return None

def __call__(
self,
data: ArrayLike,
n_timepoints: int,
) -> ArrayLike:
"""
Broadcast the data to the given number of timepoints
considering the period size and starting point.
Parameters
----------
data: ArrayLike
Data to broadcast.
n_timepoints : int
Duration of the sequence.
Notes
-----
The broadcasting is done by repeating or tiling the data. When
self.broadcast_type = "repeat", the function will repeat each value of the data `self.period_size` times until it reaches `n_timepoints`. When self.broadcast_type = "tile", the function will tile the data until it reaches `n_timepoints`.
Using the `offset` parameter, the function will start the broadcast from the `offset`-th element of the data. If the data is shorter than `n_timepoints`, the function will repeat or tile the data until it reaches `n_timepoints`.
Returns
-------
ArrayLike
Broadcasted array.
"""

if self.broadcast_type == "repeat":
return jnp.repeat(data, self.period_size)[
self.offset : (self.offset + n_timepoints)
]
elif self.broadcast_type == "tile":
return jnp.tile(
data, int(jnp.ceil(n_timepoints / self.period_size))
)[self.offset : (self.offset + n_timepoints)]
return jnp.repeat(
a=data,
repeats=period_size,
total_repeat_length=n_timepoints + offset,
)[offset : (offset + n_timepoints)]
13 changes: 3 additions & 10 deletions model/src/pyrenew/process/periodiceffect.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class PeriodicEffect(RandomVariable):
def __init__(
self,
offset: int,
period_size: int,
quantity_to_broadcast: RandomVariable,
t_start: int,
t_unit: int,
Expand All @@ -48,8 +47,6 @@ def __init__(
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
period_size : int
Size of the period.
quantity_to_broadcast : RandomVariable
Values to be broadcasted (repeated or tiled).
t_start : int
Expand All @@ -64,11 +61,7 @@ def __init__(

PeriodicEffect.validate(quantity_to_broadcast)

self.broadcaster = au.PeriodicBroadcaster(
offset=offset,
period_size=period_size,
broadcast_type="tile",
)
self.offset = offset

self.set_timeseries(
t_start=t_start,
Expand Down Expand Up @@ -114,9 +107,10 @@ def sample(self, duration: int, **kwargs):

return PeriodicEffectSample(
value=SampledValue(
self.broadcaster(
au.tile_until_n(
data=self.quantity_to_broadcast.sample(**kwargs)[0].value,
n_timepoints=duration,
offset=self.offset,
),
t_start=self.t_start,
t_unit=self.t_unit,
Expand Down Expand Up @@ -157,7 +151,6 @@ def __init__(

super().__init__(
offset=offset,
period_size=7,
quantity_to_broadcast=quantity_to_broadcast,
t_start=t_start,
t_unit=1,
Expand Down
16 changes: 8 additions & 8 deletions model/src/pyrenew/process/rtperiodicdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import NamedTuple

import jax.numpy as jnp
import pyrenew.arrayutils as au
from jax.typing import ArrayLike
from pyrenew.arrayutils import PeriodicBroadcaster
from pyrenew.metaclass import (
RandomVariable,
SampledValue,
Expand Down Expand Up @@ -77,19 +77,14 @@ def __init__(
-------
None
"""
self.name = name
self.broadcaster = PeriodicBroadcaster(
offset=offset,
period_size=period_size,
broadcast_type="repeat",
)

self.validate(
log_rt_rv=log_rt_rv,
autoreg_rv=autoreg_rv,
periodic_diff_sd_rv=periodic_diff_sd_rv,
)

self.name = name
self.period_size = period_size
self.offset = offset
self.log_rt_rv = log_rt_rv
Expand Down Expand Up @@ -192,7 +187,12 @@ def sample(

return RtPeriodicDiffProcessSample(
rt=SampledValue(
self.broadcaster(jnp.exp(log_rt.value.flatten()), duration),
au.repeat_until_n(
data=jnp.exp(log_rt.value.flatten()),
n_timepoints=duration,
offset=self.offset,
period_size=self.period_size,
),
t_start=self.t_start,
t_unit=self.t_unit,
),
Expand Down
Loading

0 comments on commit bf09a37

Please sign in to comment.