Skip to content

Commit

Permalink
numpy version
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Nov 1, 2023
1 parent 6c9bd27 commit 910e6f7
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions bamt/nodes/discrete_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,13 @@ def choose(self, node_info: Dict[str, Union[float, str]], pvals: List[str]) -> s
node_info: nodes info from distributions
pvals: parent values
"""
# NUMPY VERSION DO NOT DELETE
# vals = node_info["vals"]
# dist = np.array(self.get_dist(node_info, pvals))
#
# cumulative_dist = np.cumsum(dist)
#
# rand = np.random.random()
# rindex = np.searchsorted(cumulative_dist, rand)
#
# return vals[rindex]

vals = node_info["vals"]
dist = self.get_dist(node_info, pvals)
dist = np.array(self.get_dist(node_info, pvals))

cumulative_dist = list(accumulate(dist))
cumulative_dist = np.cumsum(dist)

rand = random.random()
rindex = next((i for i, ubound in enumerate(cumulative_dist) if rand < ubound), len(vals) - 1)
rand = np.random.random()
rindex = np.searchsorted(cumulative_dist, rand)

return vals[rindex]

Expand Down

0 comments on commit 910e6f7

Please sign in to comment.