diff --git a/mohou/trainer.py b/mohou/trainer.py index 6d06064..9d6bc9a 100644 --- a/mohou/trainer.py +++ b/mohou/trainer.py @@ -8,7 +8,17 @@ from datetime import datetime, timezone from enum import Enum from pathlib import Path -from typing import Dict, Generic, Iterable, List, Optional, Tuple, Type, TypeVar +from typing import ( + Callable, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, +) import matplotlib.pyplot as plt import numpy as np @@ -372,11 +382,11 @@ def train_lower( validate_loader: Iterable, config: TrainConfig = TrainConfig(), device: Optional[torch.device] = None, + is_stoppable: Optional[Callable[[TrainCache], None]] = None, ) -> None: r""" low-level train function that accepts train loader """ - log_package_version_info(logger, mohou) log_text_with_box(logger, "train log") logger.info("train start with config: {}".format(config)) @@ -432,6 +442,11 @@ def move_to_device(sample): logger.info("validate loss => {}".format(validate_ld_mean)) tcache.update_and_save(model, train_ld_mean, validate_ld_mean, project_path) + if is_stoppable is not None: + if is_stoppable(tcache): + logger.info("meet stop criterion. exit from trainning session.") + return + def train( project_path: Path, @@ -439,6 +454,7 @@ def train( dataset: Dataset, config: TrainConfig = TrainConfig(), device: Optional[torch.device] = None, + is_stoppable: Optional[Callable[[TrainCache], None]] = None, ) -> None: r""" higher-level train function that auto create dataloader from the dataset @@ -450,4 +466,12 @@ def train( validate_loader = DataLoader( dataset=dataset_validate, batch_size=config.batch_size, shuffle=True ) - train_lower(project_path, tcache, train_loader, validate_loader, config=config, device=device) + train_lower( + project_path, + tcache, + train_loader, + validate_loader, + config=config, + device=device, + is_stoppable=is_stoppable, + )