Skip to content

Commit

Permalink
Merge pull request #209 from pph-collective/poisson-update
Browse files Browse the repository at this point in the history
Poisson update
  • Loading branch information
s-bessey authored Dec 21, 2021
2 parents bbe820e + ad3e734 commit 4bf3b6a
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 17 deletions.
9 changes: 9 additions & 0 deletions titan/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ def weibull_modified(np_random, shape, scale):
"""
random_number = np_random.random()
return scale * (-log(1 - random_number)) ** (1 / shape)


def poisson(np_rand, mu: float):
"""
Mirrors scipy poisson.rvs function as used in code
"""
if mu < 0:
return 0
return np_rand.poisson(mu)
5 changes: 3 additions & 2 deletions titan/features/high_risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .. import agent
from .. import population
from .. import model
from ..distributions import poisson


class HighRisk(base_feature.BaseFeature):
Expand Down Expand Up @@ -170,7 +171,7 @@ def update_partner_numbers(self, pop: "population.Population", amount: int):
"""
for bond in self.agent.location.params.high_risk.partnership_types:
self.agent.mean_num_partners[bond] += amount # could be negative
self.agent.target_partners[bond] = utils.poisson(
self.agent.mean_num_partners[bond], pop.np_random
self.agent.target_partners[bond] = poisson(
pop.np_random, self.agent.mean_num_partners[bond]
)
pop.update_partnerability(self.agent)
4 changes: 2 additions & 2 deletions titan/interactions/injection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from . import base_interaction
from .. import utils
from .. import features
from .. import model
from .. import agent
from ..distributions import poisson


class Injection(base_interaction.BaseInteraction):
Expand Down Expand Up @@ -39,7 +39,7 @@ def get_num_acts(cls, model: "model.TITAN", rel: "agent.Relationship") -> int:
min(agent_params.num_acts, partner_params.num_acts)
* model.calibration.injection.act
)
share_acts = utils.poisson(mean_num_acts, model.np_random)
share_acts = poisson(model.np_random, mean_num_acts)

if share_acts < 1:
return 0
Expand Down
4 changes: 2 additions & 2 deletions titan/interactions/sex.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import base_interaction
from .. import utils
from ..distributions import poisson
from .. import model
from .. import agent

Expand All @@ -21,7 +21,7 @@ def get_num_acts(cls, model: "model.TITAN", rel: "agent.Relationship") -> int:
mean_sex_acts = (
rel.get_number_of_sex_acts(model.np_random) * model.calibration.sex.act
)
total_sex_acts = utils.poisson(mean_sex_acts, model.np_random)
total_sex_acts = poisson(model.np_random, mean_sex_acts)

# Get condom usage
p_safe_sex = (
Expand Down
5 changes: 3 additions & 2 deletions titan/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from . import utils
from . import features
from . import exposures
from .distributions import poisson


class Population:
Expand Down Expand Up @@ -434,8 +435,8 @@ def update_partner_targets(self):
"""
for a in self.all_agents:
for bond in self.params.classes.bond_types:
a.target_partners[bond] = utils.poisson(
a.mean_num_partners[bond], self.np_random
a.target_partners[bond] = poisson(
self.np_random, a.mean_num_partners[bond]
)
self.update_partnerability(a)

Expand Down
11 changes: 2 additions & 9 deletions titan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ def parse_var(dist_value, dist_type):
def get_dist(rand_gen, dist_type):
if dist_type == "randint":
return lambda *args: safe_random_int(*args, rand_gen)
elif hasattr(rand_gen, dist_type):
return getattr(rand_gen, dist_type)
elif hasattr(distributions, dist_type):
dist = getattr(distributions, dist_type)
return lambda *args: dist(rand_gen, *args)
elif hasattr(rand_gen, dist_type):
return getattr(rand_gen, dist_type)
else:
raise AttributeError(f"Distribution type {dist_type} not found!")

Expand Down Expand Up @@ -184,13 +184,6 @@ def binom_0(n: int, p: float):
return (1 - p) ** n


def poisson(mu: float, np_rand):
"""
Mirrors scipy poisson.rvs function as used in code
"""
return np_rand.poisson(mu)


def get_param_from_path(params: ObjMap, param_path: str, delimiter: str):
"""
Given a params object and a delimited path, get the leaf of the params tree
Expand Down

0 comments on commit 4bf3b6a

Please sign in to comment.