From 5da8420a3bc79f01211b92c9a3269d5a60534bab Mon Sep 17 00:00:00 2001 From: kervias Date: Wed, 18 Oct 2023 23:09:51 +0800 Subject: [PATCH] fix seed when running one fold --- edustudio/traintpl/gd_traintpl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/edustudio/traintpl/gd_traintpl.py b/edustudio/traintpl/gd_traintpl.py index ac07d34..5b0c4e8 100644 --- a/edustudio/traintpl/gd_traintpl.py +++ b/edustudio/traintpl/gd_traintpl.py @@ -1,11 +1,12 @@ from .base_traintpl import BaseTrainTPL -from edustudio.utils.common import UnifyConfig +from edustudio.utils.common import UnifyConfig, set_same_seeds import torch from torch.utils.data import DataLoader from edustudio.utils.callback import History from collections import defaultdict import numpy as np + class GDTrainTPL(BaseTrainTPL): default_cfg = { 'device': 'cuda:0', @@ -98,6 +99,7 @@ def one_fold_start(self, fold_id): fold_id (int): fold id """ self.logger.info(f"====== [FOLD ID]: {fold_id} ======") + set_same_seeds(self.traintpl_cfg['seed']) def batch_dict2device(self, batch_dict): dic = {}