Skip to content

Commit

Permalink
added vectorized choose to discrete node
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Dec 29, 2023
1 parent 689f164 commit ed82b40
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions bamt/nodes/discrete_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,59 @@ def choose(self, node_info: Dict[str, Union[float, str]], pvals: List[str]) -> s

return vals[rindex]

@staticmethod
def searchsorted_per_row(row, rand_num):
return np.searchsorted(row, rand_num, side="right")

def vectorized_choose(
self,
node_info: Dict[str, Union[float, str]],
pvals_array: np.ndarray,
n_samples: int,
) -> np.ndarray:
"""
Vectorized method to return values from a discrete node.
params:
node_info: node's info from distributions
pvals_array: array of parent values, each row corresponds to a set of parent values
n_samples: number of samples to generate
"""
vals = node_info["vals"]

# Generate a matrix of distributions
if pvals_array is None or len(pvals_array) == 0:
# Handle the case with no parent nodes
dist = np.array(node_info["cprob"])
dist_matrix = np.tile(dist, (n_samples, 1))
else:
# Ensure pvals_array is limited to current batch size
pvals_array = pvals_array[:n_samples]
# Compute distribution for each set of parent values
dist_matrix = np.array(
[self.get_dist(node_info, pvals.tolist()) for pvals in pvals_array]
)

# Ensure that dist_matrix is 2D
if dist_matrix.ndim == 1:
dist_matrix = dist_matrix.reshape(1, -1)

# Generate cumulative distributions
cumulative_dist_matrix = np.cumsum(dist_matrix, axis=1)

random_nums = np.random.rand(n_samples)

# Apply searchsorted across each row
indices = np.apply_along_axis(
self.searchsorted_per_row, 1, cumulative_dist_matrix, random_nums
)

if indices.ndim > 1:
indices = indices.flatten()

sampled_values = np.array(vals)[indices]

return sampled_values

@staticmethod
def predict(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
"""function for prediction based on evidence values in discrete node
Expand Down

0 comments on commit ed82b40

Please sign in to comment.