Skip to content

Commit

Permalink
Bug with bn after saving (aimclub#91)
Browse files Browse the repository at this point in the history
* serialization fix

* bug fix

* mixtures were excluded from list of serizalization & bug with extensions fixed
  • Loading branch information
Roman223 authored Nov 30, 2023
1 parent 1315151 commit eb55164
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
15 changes: 10 additions & 5 deletions bamt/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os.path as path
import random
import re
from copy import deepcopy
from typing import Dict, Tuple, List, Callable, Optional, Type, Union, Any, Sequence

import numpy as np
Expand Down Expand Up @@ -457,21 +458,25 @@ def save(self, bn_name, models_dir: str = "models_dir"):
:return: saving status.
"""
distributions = self.distributions.copy()
distributions = deepcopy(self.distributions)
new_weights = {str(key): self.weights[key] for key in self.weights}

to_serialize = {}
# separate logit and gaussian nodes from distributions to serialize bn's models
for node_name in self.distributions.keys():
for node_name in distributions.keys():
if "Mixture" in self[node_name].type:
continue
if self[node_name].type.startswith("Gaussian"):
if not distributions[node_name]["regressor"]:
continue
if (
"Gaussian" in self[node_name].type
or "Logit" in self[node_name].type
or "ConditionalLogit" in self[node_name].type
or "ConditionalGaussian" in self[node_name].type
):
to_serialize[node_name] = [
self[node_name].type,
self.distributions[node_name],
distributions[node_name],
]

serializer = serialization_utils.ModelsSerializer(
Expand All @@ -488,7 +493,7 @@ def save(self, bn_name, models_dir: str = "models_dir"):
"parameters": distributions,
"weights": new_weights,
}
return self._save_to_file(bn_name, outdict)
return self._save_to_file(f"{bn_name}.json", outdict)

def load(self, input_dir: str, models_dir: str = "/"):
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/sendingRegressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"Permeability",
"Depth",
]
]
].dropna()

encoder = pp.LabelEncoder()
discretizer = pp.KBinsDiscretizer(n_bins=5, encode="ordinal", strategy="quantile")
Expand All @@ -32,9 +32,9 @@

discretized_data, est = p.apply(hack_data)

bn = HybridBN()
bn = HybridBN(has_logit=True)
info = p.info
#

# with open(r"C:\Users\Roman\Desktop\mymodels\mynet.json") as f:
# net_data = json.load(f)

Expand All @@ -56,7 +56,7 @@

bn.fit_parameters(hack_data)

# bn.save("mynet.json")
# bn.save("mynet")

print(bn.sample(100).shape)
bn.get_info(as_df=False)

0 comments on commit eb55164

Please sign in to comment.