Skip to content

Commit

Permalink
Add seed parameter to BaseNetwork sample function
Browse files Browse the repository at this point in the history
Allows to generate replicable samples
  • Loading branch information
anton-golubkov committed Jun 6, 2024
1 parent 75b370f commit d2cd675
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion bamt/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,17 +638,21 @@ def sample(
predict: bool = False,
parall_count: int = 1,
filter_neg: bool = True,
seed: Optional[int] = None,
) -> Union[None, pd.DataFrame, List[Dict[str, Union[str, int, float]]]]:
"""
Sampling from Bayesian Network
n: int number of samples
evidence: values for nodes from user
parall_count: number of threads. Defaults to 1.
filter_neg: either filter negative vals or not.
seed: seed value to use for random number generator
"""
from joblib import Parallel, delayed

random.seed()
random.seed(seed)
np.random.seed(seed)

if not self.distributions.items():
logger_network.error(
"Parameter learning wasn't done. Call fit_parameters method"
Expand Down

0 comments on commit d2cd675

Please sign in to comment.