Skip to content

Commit

Permalink
Add var_names to save a subset of the variables
Browse files Browse the repository at this point in the history
  • Loading branch information
maurosilber authored and Armavica committed Jan 13, 2025
1 parent 45d47e1 commit 66320c2
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 11 deletions.
16 changes: 15 additions & 1 deletion python/rebop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import xarray as xr

from .rebop import Gillespie, __version__ # type: ignore[attr-defined]

if TYPE_CHECKING:
from collections.abc import Sequence

__all__ = ("Gillespie", "__version__")

og_run = Gillespie.run
Expand All @@ -17,13 +22,22 @@ def run_xarray( # noqa: PLR0913 too many parameters in function definition
seed: int | None = None,
*,
sparse: bool = False,
var_names: Sequence[str] | None = None,
) -> xr.Dataset:
"""Run the system until `tmax` with `nb_steps` steps.
The initial configuration is specified in the dictionary `init`.
Returns an xarray Dataset.
"""
times, result = og_run(self, init, tmax, nb_steps, seed, sparse=sparse)
times, result = og_run(
self,
init,
tmax,
nb_steps,
seed,
sparse=sparse,
var_names=var_names,
)
ds = xr.Dataset(
data_vars={
name: xr.DataArray(values, dims="time", coords={"time": times})
Expand Down
3 changes: 3 additions & 0 deletions python/rebop/rebop.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Sequence

import xarray

class Gillespie:
Expand Down Expand Up @@ -32,6 +34,7 @@ class Gillespie:
seed: int | None = None,
*,
sparse: bool = False,
var_names: Sequence[str] | None = None,
) -> xarray.Dataset:
"""Run the system until `tmax` with `nb_steps` steps.
Expand Down
37 changes: 27 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,15 @@ impl Gillespie {
/// values at the given time points. One can specify a random `seed` for reproducibility.
/// If `nb_steps` is `0`, then returns all reactions, ending with the first that happens at
/// or after `tmax`.
#[pyo3(signature = (init, tmax, nb_steps, seed=None, sparse=false))]
#[pyo3(signature = (init, tmax, nb_steps, seed=None, sparse=false, var_names=None))]
fn run(
&self,
init: HashMap<String, usize>,
tmax: f64,
nb_steps: usize,
seed: Option<u64>,
sparse: bool,
var_names: Option<Vec<String>>,
) -> PyResult<(Vec<f64>, HashMap<String, Vec<isize>>)> {
let mut x0 = vec![0; self.species.len()];
for (name, &value) in &init {
Expand All @@ -322,6 +323,13 @@ impl Gillespie {
Some(seed) => gillespie::Gillespie::new_with_seed(x0, sparse, seed),
None => gillespie::Gillespie::new(x0, sparse),
};
let save_indices: Vec<_> = match &var_names {
Some(x) => x
.iter()
.map(|key| self.species.get(key).unwrap().clone())
.collect(),
None => (0..self.species.len()).collect(),
};

for (rate, reactants, products) in self.reactions.iter() {
let mut vreactants = vec![0; self.species.len()];
Expand All @@ -340,34 +348,43 @@ impl Gillespie {
}
let mut times = Vec::new();
// species.shape = (species, nb_steps)
let mut species = vec![Vec::new(); self.species.len()];
let mut species = vec![Vec::new(); save_indices.len()];
if nb_steps > 0 {
for i in 0..=nb_steps {
let t = tmax * i as f64 / nb_steps as f64;
times.push(t);
g.advance_until(t);
for s in 0..self.species.len() {
species[s].push(g.get_species(s));
for (i, s) in save_indices.iter().enumerate() {
species[i].push(g.get_species(*s));
}
}
} else {
// nb_steps = 0: we return every step
let mut rates = vec![f64::NAN; g.nb_reactions()];
times.push(g.get_time());
for s in 0..self.species.len() {
species[s].push(g.get_species(s));
for (i, s) in save_indices.iter().enumerate() {
species[i].push(g.get_species(*s));
}
while g.get_time() < tmax {
g._advance_one_reaction(&mut rates);
times.push(g.get_time());
for s in 0..self.species.len() {
species[s].push(g.get_species(s));
for (i, s) in save_indices.iter().enumerate() {
species[i].push(g.get_species(*s));
}
}
}
let mut result = HashMap::new();
for (name, &id) in &self.species {
result.insert(name.clone(), species[id].clone());
match var_names {
Some(x) => {
for (id, name) in x.iter().enumerate() {
result.insert(name.clone(), species[id].clone());
}
}
None => {
for (name, &id) in &self.species {
result.insert(name.clone(), species[id].clone());
}
}
}
Ok((times, result))
}
Expand Down
21 changes: 21 additions & 0 deletions tests/test_rebop.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,24 @@ def test_dense_vs_sparse() -> None:
ds_dense = sir.run(init, **kwargs, sparse=False)
ds_sparse = sir.run(init, **kwargs, sparse=True)
assert (ds_dense == ds_sparse).all()


@pytest.mark.parametrize("nb_steps", [0, 250])
def test_var_names(nb_steps: int) -> None:
all_variables = {"S", "I", "R"}
subset_to_save = ["S", "I"]
remaining = all_variables.difference(subset_to_save)

sir = sir_model()
init = {"S": 999, "I": 1}
kwargs = {"tmax": 250, "nb_steps": nb_steps, "seed": 0}

ds_all = sir.run(init, **kwargs, var_names=None)
ds_subset = sir.run(init, **kwargs, var_names=subset_to_save)

for s in subset_to_save:
assert s in ds_subset
for s in remaining:
assert s not in ds_subset

assert ds_all[subset_to_save] == ds_subset

0 comments on commit 66320c2

Please sign in to comment.