diff --git a/CRISPR_DeepEnsemble.py b/CRISPR_DeepEnsemble.py index 5ab4762..4d4ac4c 100644 --- a/CRISPR_DeepEnsemble.py +++ b/CRISPR_DeepEnsemble.py @@ -1,5 +1,5 @@ import torch -import pickle +import dill as pickle import shutil import tempfile from tqdm import tqdm @@ -73,7 +73,7 @@ class RegressionDeepEnsemble: """ def __init__(self, BaseNet: torch.nn.Module=CRISPRnet, dataset='', n_estimators=5, - batch_size=25, response_var = torch.distributions.Normal, trans_func=None, + batch_size=25, response_var = torch.distributions.Normal, trans_func=lambda x : x, load_from=None): """Deep Ensembles for Univariate Regression @@ -107,7 +107,7 @@ def __init__(self, BaseNet: torch.nn.Module=CRISPRnet, dataset='', n_estimators= self.n_estimators = n_estimators self.loader = [DataLoader(dataset = dataset, batch_size=batch_size, shuffle=True) for k in range(n_estimators)] self.response_var = response_var - self.trans_func = self._default_trans_func if trans_func is None else trans_func + self.trans_func = trans_func def _default_trans_func(self, x): return x