Skip to content

Commit

Permalink
Model selection (#63)
Browse files Browse the repository at this point in the history
* 1. Selection of model for predict method.
2. Subset max shape changed to 2 to make some models working.

* Update base.py

Tests error fixed
  • Loading branch information
Roman223 authored Jun 27, 2023
1 parent 25ba073 commit 45c33e0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
11 changes: 6 additions & 5 deletions bamt/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,8 @@ def wrapper():
def predict(self,
test: pd.DataFrame,
parall_count: int = 1,
progress_bar: bool = True) -> Dict[str,
progress_bar: bool = True,
models_dir: Optional[str] = None) -> Dict[str,
Union[List[str],
List[int],
List[float]]]:
Expand All @@ -630,7 +631,7 @@ def predict(self,

from joblib import Parallel, delayed

def wrapper(bn, test: pd.DataFrame, columns: List[str]):
def wrapper(bn, test: pd.DataFrame, columns: List[str], models_dir: str):
preds = {column_name: list() for column_name in columns}

if len(test) == 1:
Expand All @@ -639,7 +640,7 @@ def wrapper(bn, test: pd.DataFrame, columns: List[str]):
for n, key in enumerate(columns):
try:
sample = bn.sample(
1, evidence=test_row, predict=True, progress_bar=False)
1, evidence=test_row, predict=True, progress_bar=False, models_dir=models_dir)
if sample.empty:
preds[key].append(np.nan)
continue
Expand Down Expand Up @@ -670,10 +671,10 @@ def wrapper(bn, test: pd.DataFrame, columns: List[str]):

if progress_bar:
processed_list = Parallel(n_jobs=parall_count)(delayed(wrapper)(
self, test.loc[[i]], columns) for i in tqdm(test.index, position=0, leave=True))
self, test.loc[[i]], columns, models_dir) for i in tqdm(test.index, position=0, leave=True))
else:
processed_list = Parallel(n_jobs=parall_count)(
delayed(wrapper)(self, test.loc[[i]], columns) for i in test.index)
delayed(wrapper)(self, test.loc[[i]], columns, models_dir) for i in test.index)

for i in range(test.shape[0]):
curr_pred = processed_list[i]
Expand Down
2 changes: 1 addition & 1 deletion bamt/nodes/conditional_gaussian_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def fit_parameters(
mask = (mask) & (data[col] == val)
new_data = data[mask]
key_comb = [str(x) for x in comb]
if new_data.shape[0] > 0:
if new_data.shape[0] > 1:
if self.cont_parents:
model = self.regressor
model.fit(new_data[self.cont_parents].values,
Expand Down
1 change: 0 additions & 1 deletion tests/sendingRegressors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# import json

# import bamt.networks as Nets
from bamt.networks.hybrid_bn import HybridBN
import bamt.preprocessors as preprocessors

Expand Down

0 comments on commit 45c33e0

Please sign in to comment.