Skip to content

Commit

Permalink
new choose method for discrete node
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Nov 1, 2023
1 parent 82dd79e commit 6c9bd27
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions bamt/nodes/discrete_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import random
from itertools import product
from itertools import product, accumulate
from typing import Type, Dict, Union, List

import numpy as np
from pandas import DataFrame, crosstab

from .base import BaseNode
Expand Down Expand Up @@ -73,22 +74,24 @@ def choose(self, node_info: Dict[str, Union[float, str]], pvals: List[str]) -> s
node_info: nodes info from distributions
pvals: parent values
"""
rindex = 0
random.seed()
vals = node_info["vals"]
# NUMPY VERSION DO NOT DELETE
# vals = node_info["vals"]
# dist = np.array(self.get_dist(node_info, pvals))
#
# cumulative_dist = np.cumsum(dist)
#
# rand = np.random.random()
# rindex = np.searchsorted(cumulative_dist, rand)
#
# return vals[rindex]

vals = node_info["vals"]
dist = self.get_dist(node_info, pvals)

lbound = 0
ubound = 0
cumulative_dist = list(accumulate(dist))

rand = random.random()
for interval in range(len(dist)):
ubound += dist[interval]
if lbound <= rand < ubound:
rindex = interval
break
else:
lbound = ubound
rindex = next((i for i, ubound in enumerate(cumulative_dist) if rand < ubound), len(vals) - 1)

return vals[rindex]

Expand Down

0 comments on commit 6c9bd27

Please sign in to comment.