diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index ddd35b28..ebe5748c 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -32,9 +32,10 @@ class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
-                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
+                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), halve_lr_epochs=-1,
                  check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True,
                  use_cuda=False, callbacks=None):
+
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -47,6 +48,7 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
         :param DataSet dev_data: the validation data
         :param str save_path: file path to save models
         :param Optimizer optimizer: an optimizer object
+        :param int halve_lr_epochs: halve the learning rate if the dev metric has not improved for [halve_lr_epochs] epochs. Default: -1 (disabled)
         :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\
             `ignore` will not check unused fields; `warning` will warn if some fields are not used; `strict` means
             it will raise an error if some fields are not used. The check works by running a very small batch (two samples by default) to verify whether the code
@@ -108,6 +110,7 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
         self.dev_data = dev_data  # If None, no validation.
         self.model = model
         self.losser = losser
+        self.halve_lr_epochs = halve_lr_epochs
         self.metrics = metrics
         self.n_epochs = int(n_epochs)
         self.batch_size = int(batch_size)
@@ -130,6 +133,10 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
         self.use_tqdm = use_tqdm
         self.print_every = abs(self.print_every)
+
+        # record the initial learning rate (assumes all param groups share one lr)
+        for group in self.optimizer.param_groups:
+            self.lr = group['lr']
 
         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
@@ -227,6 +233,7 @@ def _train(self):
         else:
             inner_tqdm = tqdm
         self.step = 0
+        self.bad_valid = 0
         start = time.time()
         total_steps = (len(self.train_data) // self.batch_size + int(
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
@@ -302,6 +309,24 @@ def _do_validation(self, epoch, step):
             self.best_dev_perf = res
             self.best_dev_epoch = epoch
             self.best_dev_step = step
+
+        # halve the learning rate if not improving for [halve_lr_epochs] epochs, and restart training from the best model.
+        else:
+            self.bad_valid += 1
+            if self.halve_lr_epochs != -1:
+                if self.bad_valid >= self.halve_lr_epochs:
+                    self.lr = self.lr / 2.0
+                    print("Halved learning rate to {}".format(self.lr))
+                    model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
+                    load_succeed = self._load_model(self.model, model_name)
+                    if load_succeed:
+                        print("Reloaded the best model.")
+                    else:
+                        print("Failed to reload the best model.")
+                    self._set_lr(self.optimizer, self.lr)
+                    self.bad_valid = 0
+            print("bad valid: {}".format(self.bad_valid))
+
         # get validation results; adjust optimizer
         self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
         return res
@@ -406,7 +431,13 @@ def _better_eval_result(self, metrics):
             else:
                 is_better = False
         return is_better
-
+
+    def _set_lr(self, optimizer, lr):
+        # Set the learning rate on every param group of a torch optimizer.
+        # Note: an optimizer wrapper such as YFOptimizer would need its own
+        # set_lr_factor(lr) call instead of writing param_groups directly.
+        for group in optimizer.param_groups:
+            group['lr'] = lr
 
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2
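
For reviewers, a minimal sketch of how the new option would be used. The data and model objects below are hypothetical placeholders, and the import paths assume fastNLP's module layout as of this patch (fastNLP/core/*):

```python
# Hypothetical usage of the new halve_lr_epochs option.
# train_set, dev_set, and model are placeholders, not part of this patch.
from fastNLP.core.trainer import Trainer
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.optimizer import Adam

trainer = Trainer(train_data=train_set,    # a fastNLP DataSet
                  model=model,             # any torch.nn.Module
                  loss=CrossEntropyLoss(),
                  metrics=AccuracyMetric(),
                  dev_data=dev_set,        # required: halving is driven by dev results
                  save_path="./models",    # required: the best model is reloaded from disk
                  optimizer=Adam(lr=0.01, weight_decay=0),
                  halve_lr_epochs=2)       # halve lr after 2 non-improving validations
trainer.train()
```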
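Functionally this is close to torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=halve_lr_epochs), except that the patch also rewinds the model to its best saved checkpoint before continuing. A framework-independent sketch of that combined pattern, with hypothetical train_one_epoch/evaluate helpers, looks like this:

```python
# Plain-PyTorch sketch of halve-on-plateau plus best-checkpoint reload.
# train_one_epoch and evaluate are hypothetical placeholders.
import torch

def fit(model, optimizer, n_epochs, patience, ckpt="best.pt"):
    best_score, bad_valid = None, 0
    for epoch in range(n_epochs):
        train_one_epoch(model, optimizer)   # one pass over the training data
        score = evaluate(model)             # higher dev score = better
        if best_score is None or score > best_score:
            best_score, bad_valid = score, 0
            torch.save(model.state_dict(), ckpt)  # remember the best weights
        else:
            bad_valid += 1
            if bad_valid >= patience:
                for group in optimizer.param_groups:
                    group['lr'] /= 2.0                    # halve the learning rate
                model.load_state_dict(torch.load(ckpt))  # restart from the best model
                bad_valid = 0
```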