Skip to content

Commit

Permalink
code review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Sep 17, 2024
1 parent 542ad84 commit bafe3b0
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 80 deletions.
2 changes: 1 addition & 1 deletion pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def sample(
)

if inf_feedback_strength.ndim == Rt.ndim - 1:
inf_feedback_strength = jnp.expand_dims(inf_feedback_strength, 0)
inf_feedback_strength = inf_feedback_strength[jnp.newaxis]

# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.shape[0] == 1:
Expand Down
79 changes: 0 additions & 79 deletions test/test_infection_and_infectionwithfeedback.py

This file was deleted.

67 changes: 67 additions & 0 deletions test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,70 @@ def test_infectionsrtfeedback_feedback(Rt, I0):
assert_array_almost_equal(samp1.rt, res["rt"])

return None


def test_infections_with_feedback_invalid_inputs():
"""
Test the InfectionsWithFeedback class cannot
be sampled when Rt and I0 have invalid input shapes
"""
I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8])
I0_2d = jnp.array(
np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3)
).reshape((7, -1))
Rt = jnp.ones(10)
gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0])

inf_feed_strength = DeterministicVariable(
name="inf_feed_strength", value=0.5
)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int)

# Test the InfectionsWithFeedback class
InfectionsWithFeedback = latent.InfectionsWithFeedback(
infection_feedback_strength=inf_feed_strength,
infection_feedback_pmf=inf_feedback_pmf,
)

infections = latent.Infections()

with numpyro.handlers.seed(rng_seed=0):
with pytest.raises(
ValueError,
match="Initial infections must be at least as long as the generation interval.",
):
InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
I0=I0_1d,
)

with pytest.raises(
ValueError,
match="Initial infections vector must be at least as long as the generation interval.",
):
infections(
gen_int=gen_int,
Rt=Rt,
I0=I0_1d,
)

with pytest.raises(
ValueError,
match="Initial infections and Rt must have the same batch shapes.",
):
InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
I0=I0_2d,
)

with pytest.raises(
ValueError,
match="Initial infections and Rt must have the same batch shapes.",
):
infections(
gen_int=gen_int,
Rt=Rt,
I0=I0_2d,
)

0 comments on commit bafe3b0

Please sign in to comment.