Skip to content

Commit

Permalink
fix seed when running one fold
Browse files Browse the repository at this point in the history
  • Loading branch information
kervias committed Oct 18, 2023
1 parent fcea303 commit 5da8420
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion edustudio/traintpl/gd_traintpl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .base_traintpl import BaseTrainTPL
from edustudio.utils.common import UnifyConfig
from edustudio.utils.common import UnifyConfig, set_same_seeds
import torch
from torch.utils.data import DataLoader
from edustudio.utils.callback import History
from collections import defaultdict
import numpy as np


class GDTrainTPL(BaseTrainTPL):
default_cfg = {
'device': 'cuda:0',
Expand Down Expand Up @@ -98,6 +99,7 @@ def one_fold_start(self, fold_id):
fold_id (int): fold id
"""
self.logger.info(f"====== [FOLD ID]: {fold_id} ======")
set_same_seeds(self.traintpl_cfg['seed'])

def batch_dict2device(self, batch_dict):
dic = {}
Expand Down

0 comments on commit 5da8420

Please sign in to comment.