Skip to content

Commit

Permalink
Variable refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
CeSchmitz committed Feb 17, 2025
1 parent fedcf84 commit 8105754
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions CRISPR_DeepEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class RegressionDeepEnsemble:

def __init__(self, BaseNet: torch.nn.Module=CRISPRnet, dataset='', n_estimators=5,
batch_size=25, response_var = torch.distributions.Normal, trans_func=None,
load_from_pickle=None):
load_from=None):

"""Deep Ensembles for Univariate Regression
Expand All @@ -100,8 +100,8 @@ def __init__(self, BaseNet: torch.nn.Module=CRISPRnet, dataset='', n_estimators=
Defaults to identity function.
load_from_pickle (path_like): Path of the zip file to load model from
"""
if load_from_pickle:
self.load(load_from_pickle)
if load_from:
self.load(load_from)
else:
self.ensemble = [BaseNet() for k in range(n_estimators)]
self.n_estimators = n_estimators
Expand Down

0 comments on commit 8105754

Please sign in to comment.