Skip to content

Commit

Permalink
swapped to dill to allow lambdas
Browse files Browse the repository at this point in the history
  • Loading branch information
CeSchmitz committed Feb 20, 2025
1 parent 8105754 commit d907632
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions CRISPR_DeepEnsemble.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
import pickle
import dill as pickle
import shutil
import tempfile
from tqdm import tqdm
Expand Down Expand Up @@ -73,7 +73,7 @@ class RegressionDeepEnsemble:
"""

def __init__(self, BaseNet: torch.nn.Module=CRISPRnet, dataset='', n_estimators=5,
batch_size=25, response_var = torch.distributions.Normal, trans_func=None,
batch_size=25, response_var = torch.distributions.Normal, trans_func=lambda x : x,
load_from=None):

"""Deep Ensembles for Univariate Regression
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(self, BaseNet: torch.nn.Module=CRISPRnet, dataset='', n_estimators=
self.n_estimators = n_estimators
self.loader = [DataLoader(dataset = dataset, batch_size=batch_size, shuffle=True) for k in range(n_estimators)]
self.response_var = response_var
self.trans_func = self._default_trans_func if trans_func is None else trans_func
self.trans_func = trans_func

def _default_trans_func(self, x):
return x
Expand Down

0 comments on commit d907632

Please sign in to comment.