From eb5516433141bfd75c2aa09459cdac24d89cd723 Mon Sep 17 00:00:00 2001 From: Roman Netrogolov <68499591+Roman223@users.noreply.github.com> Date: Thu, 30 Nov 2023 17:01:26 +0300 Subject: [PATCH] Bug with bn after saving (#91) * serialization fix * bug fix * mixtures were excluded from list of serizalization & bug with extensions fixed --- bamt/networks/base.py | 15 ++++++++++----- tests/sendingRegressors.py | 8 ++++---- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/bamt/networks/base.py b/bamt/networks/base.py index 09ddc7d..9d67178 100644 --- a/bamt/networks/base.py +++ b/bamt/networks/base.py @@ -2,6 +2,7 @@ import os.path as path import random import re +from copy import deepcopy from typing import Dict, Tuple, List, Callable, Optional, Type, Union, Any, Sequence import numpy as np @@ -457,21 +458,25 @@ def save(self, bn_name, models_dir: str = "models_dir"): :return: saving status. """ - distributions = self.distributions.copy() + distributions = deepcopy(self.distributions) new_weights = {str(key): self.weights[key] for key in self.weights} to_serialize = {} # separate logit and gaussian nodes from distributions to serialize bn's models - for node_name in self.distributions.keys(): + for node_name in distributions.keys(): + if "Mixture" in self[node_name].type: + continue + if self[node_name].type.startswith("Gaussian"): + if not distributions[node_name]["regressor"]: + continue if ( "Gaussian" in self[node_name].type or "Logit" in self[node_name].type or "ConditionalLogit" in self[node_name].type - or "ConditionalGaussian" in self[node_name].type ): to_serialize[node_name] = [ self[node_name].type, - self.distributions[node_name], + distributions[node_name], ] serializer = serialization_utils.ModelsSerializer( @@ -488,7 +493,7 @@ def save(self, bn_name, models_dir: str = "models_dir"): "parameters": distributions, "weights": new_weights, } - return self._save_to_file(bn_name, outdict) + return self._save_to_file(f"{bn_name}.json", outdict) def load(self, input_dir: str, models_dir: str = "/"): """ diff --git a/tests/sendingRegressors.py b/tests/sendingRegressors.py index 07f5841..50a6bca 100644 --- a/tests/sendingRegressors.py +++ b/tests/sendingRegressors.py @@ -23,7 +23,7 @@ "Permeability", "Depth", ] -] +].dropna() encoder = pp.LabelEncoder() discretizer = pp.KBinsDiscretizer(n_bins=5, encode="ordinal", strategy="quantile") @@ -32,9 +32,9 @@ discretized_data, est = p.apply(hack_data) -bn = HybridBN() +bn = HybridBN(has_logit=True) info = p.info -# + # with open(r"C:\Users\Roman\Desktop\mymodels\mynet.json") as f: # net_data = json.load(f) @@ -56,7 +56,7 @@ bn.fit_parameters(hack_data) -# bn.save("mynet.json") +# bn.save("mynet") print(bn.sample(100).shape) bn.get_info(as_df=False)