Skip to content

Commit

Permalink
[fix] reset set_same_seed position in TrainTPL
Browse files Browse the repository at this point in the history
  • Loading branch information
kervias committed Dec 18, 2023
1 parent 18d7312 commit ba6a481
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion edustudio/traintpl/edu_traintpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def one_fold_start(self, fold_id):
test_loader=self.test_loader
)
# train
set_same_seeds(self.traintpl_cfg['seed'])
#set_same_seeds(self.traintpl_cfg['seed'])
if self.valid_loader is not None:
self.fit(train_loader=self.train_loader, valid_loader=self.valid_loader)
else:
Expand Down
2 changes: 1 addition & 1 deletion edustudio/traintpl/gd_traintpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def start(self):

extra_data = self.datatpl.get_extra_data()

set_same_seeds(self.traintpl_cfg['seed'])
self.model = self.get_model_obj()
self.model.build_cfg()
self.model.add_extra_data(**extra_data)
Expand Down Expand Up @@ -99,7 +100,6 @@ def one_fold_start(self, fold_id):
fold_id (int): fold id
"""
self.logger.info(f"====== [FOLD ID]: {fold_id} ======")
set_same_seeds(self.traintpl_cfg['seed'])

def batch_dict2device(self, batch_dict):
dic = {}
Expand Down

0 comments on commit ba6a481

Please sign in to comment.