import os

from abc import ABC, abstractmethod


def _save_checkpoint(model, agent, checkpoint_path, epoch):
    """Persist the model together with the agent's training state.

    Args:
        model (nn.Module)
        agent (Object): Provides ``criterion``, ``optimizer`` and ``lr_scheduler``.
        checkpoint_path (str): Destination file of the checkpoint.
        epoch (int): The epoch that has just finished. ``epoch + 1`` is handed
            to ``save`` as the last positional argument.
    """
    # Function-scope import: keeps this module importable even when the
    # project-level `utils` package is not on the path (e.g. isolated tests).
    from utils import save

    save(
        model,
        checkpoint_path,
        agent.criterion,
        agent.optimizer,
        agent.lr_scheduler,
        epoch + 1)


class MetaTrainingAgent:
    """Base class holding the training loop shared by every training agent.

    Subclasses provide ``_training_step`` and ``_validate_step``; the
    state-specific wrappers (``_search_*`` / ``_evaluate_*``) forward to them
    by default and may be overridden individually.

    Note: this class does not use ``ABCMeta``, so the ``@abstractmethod``
    markers below are not enforced at instantiation time. A missing hook only
    surfaces as ``NotImplementedError`` when it is actually called.
    """

    def train_loop(self, model, train_loader, val_loader, agent):
        """ The main training loop.

        For every epoch in ``range(agent.start_epochs, agent.epochs)`` the loop
        runs one training step and one validate step, stores a "best"
        checkpoint whenever the validation performance improves, and always
        stores a per-epoch checkpoint.

        Args:
            model (nn.Module)
            train_loader (torch.utils.data.DataLoader)
            val_loader (torch.utils.data.DataLoader)
            agent (Object)

        Raises:
            AttributeError: If no ``_{agent.agent_state}_training_step`` /
                ``_{agent.agent_state}_validate_step`` method exists.
        """
        # Pick the step methods matching the current agent state.
        training_step = getattr(self, f"_{agent.agent_state}_training_step")
        validate_step = getattr(self, f"_{agent.agent_state}_validate_step")

        best_val_performance = -float("inf")
        for epoch in range(agent.start_epochs, agent.epochs):
            agent.logger.info(f"Start to train for epoch {epoch}")
            agent.logger.info(f"Learning Rate : {agent.optimizer.param_groups[0]['lr']:.8f}")

            training_step(model, train_loader, agent, epoch)
            val_performance = validate_step(model, val_loader, agent, epoch)

            if val_performance > best_val_performance:
                agent.logger.info(f"Best validation performance : {val_performance}. Save model!")
                best_val_performance = val_performance
                _save_checkpoint(
                    model,
                    agent,
                    agent.config["experiment_path"]["best_checkpoint_path"],
                    epoch)

            # The per-epoch checkpoint is written regardless of the validation result.
            _save_checkpoint(
                model,
                agent,
                os.path.join(
                    agent.config["experiment_path"]["checkpoint_root_path"],
                    f"{agent.agent_state}_{epoch}.pth"),
                epoch)

    def _search_training_step(self, model, train_loader, agent, epoch):
        """ The training step for searching process. Users should step the sampler
        to decide how to train supernet and step the search strategy to search the architecture.

        Args:
            model (nn.Module)
            train_loader (torch.utils.data.DataLoader)
            agent (Object): The search agent.
            epoch (int)
        """
        self._training_step(model, train_loader, agent, epoch)

    def _evaluate_training_step(self, model, train_loader, agent, epoch):
        """ The training step for evaluating process (training from scratch).

        Args:
            model (nn.Module)
            train_loader (torch.utils.data.DataLoader)
            agent (Object): The evaluate agent
            epoch (int)
        """
        self._training_step(model, train_loader, agent, epoch)

    def _search_validate_step(self, model, val_loader, agent, epoch):
        """ The validate step for searching process.

        Args:
            model (nn.Module)
            val_loader (torch.utils.data.DataLoader)
            agent (Object): The search agent.
            epoch (int)

        Return:
            evaluate_metric (float): The performance of the supernet
        """
        return self._validate_step(model, val_loader, agent, epoch)

    def _evaluate_validate_step(self, model, val_loader, agent, epoch):
        """ The validate step for evaluating process (training from scratch).

        Args:
            model (nn.Module)
            val_loader (torch.utils.data.DataLoader)
            agent (Object): The evaluate agent
            epoch (int)

        Return:
            evaluate_metric (float): The performance of the searched model.
        """
        return self._validate_step(model, val_loader, agent, epoch)

    @abstractmethod
    def _training_step(self, model, train_loader, agent, epoch):
        """ Train the model for one epoch. Must be provided by subclasses.

        Args:
            model (nn.Module)
            train_loader (torch.utils.data.DataLoader)
            agent (Object): The search or evaluate agent
            epoch (int)

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def _validate_step(self, model, val_loader, agent, epoch):
        """ Validate the model after one epoch. Must be provided by subclasses.

        Args:
            model (nn.Module)
            val_loader (torch.utils.data.DataLoader)
            agent (Object): The search or evaluate agent
            epoch (int)

        Return:
            evaluate_metric (float): The performance of the model. ``train_loop``
                compares it with ``>`` so larger must mean better.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
Save model!") - best_val_performance = val_performance - save( - model, - agent.config["experiment_path"]["best_checkpoint_path"], - agent.criterion, - agent.optimizer, - agent.lr_scheduler, - epoch + 1) - - save( - model, - os.path.join( - agent.config["experiment_path"]["checkpoint_root_path"], - f"{agent.agent_state}_{epoch}.pth"), - agent.criterion, - agent.optimizer, - agent.lr_scheduler, - epoch + 1) - - - def _search_training_step(self, model, train_loader, agent, epoch): - """ The training step for searching process. Users should step the sampler - to decide how to train supernet and step the search strategy to search the architecture. - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - agent (Object): The search agent. - epoch (int) - """ - self._training_step(model, train_loader, agent, epoch) - - def _search_validate_step(self, model, val_loader, agent, epoch): - """ The validate step for searching process. - - Args: - model (nn.Module) - val_loader (torch.utils.data.DataLoader) - agent (Object): The search agent. - epoch (int) - - Return: - evaluate_metric (float): The performance of the supernet - """ - return self._validate(model ,val_loader, agent, epoch) - - def _evaluate_training_step(self, model, train_loader, agent, epoch): - """ The training step for evaluating process (training from scratch). - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - agent (Object): The evaluate agent - epoch (int) - """ - self._training_step(model, train_loader, agent, epoch) - - def _evaluate_validate_step(self, model, val_loader, agent, epoch): - """ The training step for evaluating process (training from scratch). - - Args: - model (nn.Module) - val_loader (torch.utils.data.DataLoader) - agent (Object): The evaluate agent - epoch (int) - - Return: - evaluate_metric (float): The performance of the searched model. 
- """ - return self._validate(model ,val_loader, agent, epoch) - @staticmethod def searching_evaluate(model, val_loader, device, criterion): """ Evaluating the performance of the supernet. The search strategy will evaluate @@ -150,6 +49,13 @@ def searching_evaluate(model, val_loader, device, criterion): def _training_step(self, model, train_loader, agent, epoch, print_freq=100): + """ + Args: + model (nn.Module) + train_loader (torch.utils.data.DataLoader) + agent (Object): The evaluate agent + epoch (int) + """ top1 = AverageMeter() top5 = AverageMeter() losses = AverageMeter() @@ -182,7 +88,7 @@ def _training_step(self, model, train_loader, agent, epoch, print_freq=100): top5.update(prec5.item(), N) if (step > 1 and step % print_freq == 0) or (step == len(train_loader) - 1): - agent.logger.info(f"Train : [{(epoch+1):3d}/{agent.epochs}]" + agent.logger.info(f"Train : [{(epoch+1):3d}/{agent.epochs}] " f"Step {step:3d}/{len(train_loader)-1:3d} Loss {losses.get_avg():.3f}" f"Prec@(1, 5) ({top1.get_avg():.1%}, {top5.get_avg():.1%})") @@ -196,7 +102,17 @@ def _training_step(self, model, train_loader, agent, epoch, print_freq=100): f"Time {time.time() - start_time:.2f}") - def _validate(self, model, val_loader, agent, epoch): + def _validate_step(self, model, val_loader, agent, epoch): + """ + Args: + model (nn.Module) + val_loader (torch.utils.data.DataLoader) + agent (Object): The evaluate agent + epoch (int) + + Return: + evaluate_metric (float): The performance of the searched model. 
+ """ model.eval() start_time = time.time() diff --git a/agent/face_agent/training_agent.py b/agent/face_agent/training_agent.py index 966a1aa..c4ef44c 100644 --- a/agent/face_agent/training_agent.py +++ b/agent/face_agent/training_agent.py @@ -9,75 +9,14 @@ from utils import AverageMeter, save from .face_evaluate import evaluate +from ..base_training_agent import MetaTrainingAgent -class FRTrainingAgent: +class FRTrainingAgent(MetaTrainingAgent): """The training agent to train the supernet and the searched architecture. By implementing TrainingAgent class, users can adapt the searching and evaluating agent into various tasks easily. """ - def train_loop(self, model, train_loader, val_loader, agent): - """ The main training loop. - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - val_loader (torch.utils.data.DataLoader) - agent (Object) - """ - training_step = getattr(self, f"_{agent.agent_state}_training_step") - validate_step = getattr(self, f"_{agent.agent_state}_validate_step") - - best_val_performance = -float("inf") - for epoch in range(agent.start_epochs, agent.epochs): - agent.logger.info(f"Start to train for epoch {epoch}") - agent.logger.info(f"Learning Rate : {agent.optimizer.param_groups[0]['lr']:.8f}") - - training_step( - model, - train_loader, - agent, - epoch) - val_performance = validate_step( - model, - val_loader, - agent, - epoch) - - if val_performance > best_val_performance: - agent.logger.info(f"Best validation performance : {val_performance}. 
Save model!") - best_val_performance = val_performance - save( - model, - agent.config["experiment_path"]["best_checkpoint_path"], - agent.criterion, - agent.optimizer, - agent.lr_scheduler, - epoch + 1) - - save( - model, - os.path.join( - agent.config["experiment_path"]["checkpoint_root_path"], - f"{agent.agent_state}_{epoch}.pth"), - agent.criterion, - agent.optimizer, - agent.lr_scheduler, - epoch + 1) - - - def _search_training_step(self, model, train_loader, agent, epoch): - """ The training step for searching process. Users should step the sampler - to decide how to train supernet and step the search strategy to search the architecture. - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - agent (Object): The search agent. - epoch (int) - """ - self._training_step(model, train_loader, agent, epoch) - def _search_validate_step(self, model, val_loader, agent, epoch): """ The validate step for searching process. @@ -105,17 +44,6 @@ def _search_validate_step(self, model, val_loader, agent, epoch): return minus_losses_avg - def _evaluate_training_step(self, model, train_loader, agent, epoch): - """ The training step for evaluating process (training from scratch). - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - agent (Object): The evaluate agent - epoch (int) - """ - self._training_step(model, train_loader, agent, epoch) - def _evaluate_validate_step(self, model, val_loader, agent, epoch): """ The training step for evaluating process (training from scratch). 
@@ -210,6 +138,13 @@ def _training_step( agent, epoch, print_freq=100): + """ + Args: + model (nn.Module) + train_loader (torch.utils.data.DataLoader) + agent (Object): The evaluate agent + epoch (int) + """ losses = AverageMeter() model.train() diff --git a/doc/customize/agent.md b/doc/customize/agent.md index 678f981..e7e2948 100644 --- a/doc/customize/agent.md +++ b/doc/customize/agent.md @@ -33,100 +33,37 @@ import time import torch from utils import AverageMeter, save +from ..base_training_agent import MetaTrainingAgent -class {{customize_class}}TrainingAgent: +class {{customize_class}}TrainingAgent(MetaTrainingAgent): """The training agent to train the supernet and the searched architecture. By implementing TrainingAgent class, users can adapt the searching and evaluating agent into various tasks easily. """ - def train_loop(self, model, train_loader, val_loader, agent): - """ The main training loop. - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - val_loader (torch.utils.data.DataLoader) - agent (Object) - """ - # Utilize different step method based on differet agent state - training_step = getattr(self, f"_{agent.agent_state}_training_step") - validate_step = getattr(self, f"_{agent.agent_state}_validate_step") - - best_val_performance = -float("inf") - for epoch in range(agent.start_epochs, agent.epochs): - agent.logger.info(f"Start to train for epoch {epoch}") - agent.logger.info(f"Learning Rate : {agent.optimizer.param_groups[0]['lr']:.8f}") - - training_step( - model, - train_loader, - agent, - epoch) - val_performance = validate_step( - model, - val_loader, - agent, - epoch) - - if val_performance > best_val_performance: - agent.logger.info(f"Best validation performance : {val_performance}. 
Save model!") - best_val_performance = val_performance - save( - model, - agent.config["experiment_path"]["best_checkpoint_path"], - agent.optimizer, - agent.lr_scheduler, - epoch + 1) - - save( - model, - os.path.join( - agent.config["experiment_path"]["checkpoint_root_path"], - f"{agent.agent_state}_{epoch}.pth"), - agent.optimizer, - agent.lr_scheduler, - epoch + 1) - - - def _search_training_step(self, model, train_loader, agent, epoch): - """ The training step for searching process. Users should step the sampler - to decide how to train supernet and step the search strategy to search the architecture. - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - agent (Object): The search agent. - epoch (int) - """ - model.train() - for step, (X, y) in enumerate(train_loader): - agent._iteration_preprocess() - raise NotImplemented - - def _search_validate_step(self, model, val_loader, agent, epoch): - """ The validate step for searching process. + @staticmethod + def searching_evaluate(model, val_loader, device, criterion): + """ Evaluating the performance of the supernet. The search strategy will evaluate + the architectures by this static method to search. Args: model (nn.Module) val_loader (torch.utils.data.DataLoader) - agent (Object): The search agent. - epoch (int) + device (torch.device) + criterion (nn.Module) Return: - evaluate_metric (float): The performance of the supernet + evaluate_metric (float): The performance of the supernet. """ model.eval() with torch.no_grad(): for step, (X, y) in enumerate(val_loader): - agent._iteration_preprocess() - raise NotImplemented return evaluate_metric - def _evaluate_training_step(self, model, train_loader, agent, epoch): - """ The training step for evaluating process (training from scratch). 
+ def _training_step(self, model, train_loader, agent, epoch, print_freq=100): + """ Args: model (nn.Module) train_loader (torch.utils.data.DataLoader) @@ -134,12 +71,16 @@ class {{customize_class}}TrainingAgent: epoch (int) """ model.train() + start_time = time.time() + for step, (X, y) in enumerate(train_loader): - raise NotImplemented + if agent.agent_state == "search": + agent._iteration_preprocess() - def _evaluate_validate_step(self, model, val_loader, agent, epoch): - """ The training step for evaluating process (training from scratch). + raise NotImplemented + def _validate_step(self, model, val_loader, agent, epoch): + """ Args: model (nn.Module) val_loader (torch.utils.data.DataLoader) @@ -150,30 +91,10 @@ class {{customize_class}}TrainingAgent: evaluate_metric (float): The performance of the searched model. """ model.eval() - with torch.no_grad(): - for step, (X, y) in enumerate(val_loader): - raise NotImplemented - return evaluate_metric - - @staticmethod - def searching_evaluate(model, val_loader, device, criterion): - """ Evaluating the performance of the supernet. The search strategy will evaluate - the architectures by this static method to search. - - Args: - model (nn.Module) - val_loader (torch.utils.data.DataLoader) - device (torch.device) - criterion (nn.Module) - - Return: - evaluate_metric (float): The performance of the supernet. 
- """ - model.eval() - with torch.no_grad(): - for step, (X, y) in enumerate(val_loader): - raise NotImplemented - return evaluate_metric + start_time = time.time() + if agent.agent_state == "search": + agent._iteration_preprocess() + raise NotImplemented ``` diff --git a/template/agent/training_template.py b/template/agent/training_template.py index f747824..1fb218e 100644 --- a/template/agent/training_template.py +++ b/template/agent/training_template.py @@ -4,100 +4,37 @@ import torch from utils import AverageMeter, save +from ..base_training_agent import MetaTrainingAgent -class {{customize_class}}TrainingAgent: +class {{customize_class}}TrainingAgent(MetaTrainingAgent): """The training agent to train the supernet and the searched architecture. By implementing TrainingAgent class, users can adapt the searching and evaluating agent into various tasks easily. """ - def train_loop(self, model, train_loader, val_loader, agent): - """ The main training loop. - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - val_loader (torch.utils.data.DataLoader) - agent (Object) - """ - # Utilize different step method based on differet agent state - training_step = getattr(self, f"_{agent.agent_state}_training_step") - validate_step = getattr(self, f"_{agent.agent_state}_validate_step") - - best_val_performance = -float("inf") - for epoch in range(agent.start_epochs, agent.epochs): - agent.logger.info(f"Start to train for epoch {epoch}") - agent.logger.info(f"Learning Rate : {agent.optimizer.param_groups[0]['lr']:.8f}") - - training_step( - model, - train_loader, - agent, - epoch) - val_performance = validate_step( - model, - val_loader, - agent, - epoch) - - if val_performance > best_val_performance: - agent.logger.info(f"Best validation performance : {val_performance}. 
Save model!") - best_val_performance = val_performance - save( - model, - agent.config["experiment_path"]["best_checkpoint_path"], - agent.optimizer, - agent.lr_scheduler, - epoch + 1) - - save( - model, - os.path.join( - agent.config["experiment_path"]["checkpoint_root_path"], - f"{agent.agent_state}_{epoch}.pth"), - agent.optimizer, - agent.lr_scheduler, - epoch + 1) - - - def _search_training_step(self, model, train_loader, agent, epoch): - """ The training step for searching process. Users should step the sampler - to decide how to train supernet and step the search strategy to search the architecture. - - Args: - model (nn.Module) - train_loader (torch.utils.data.DataLoader) - agent (Object): The search agent. - epoch (int) - """ - model.train() - for step, (X, y) in enumerate(train_loader): - agent._iteration_preprocess() - raise NotImplemented - - def _search_validate_step(self, model, val_loader, agent, epoch): - """ The validate step for searching process. + @staticmethod + def searching_evaluate(model, val_loader, device, criterion): + """ Evaluating the performance of the supernet. The search strategy will evaluate + the architectures by this static method to search. Args: model (nn.Module) val_loader (torch.utils.data.DataLoader) - agent (Object): The search agent. - epoch (int) + device (torch.device) + criterion (nn.Module) Return: - evaluate_metric (float): The performance of the supernet + evaluate_metric (float): The performance of the supernet. """ model.eval() with torch.no_grad(): for step, (X, y) in enumerate(val_loader): - agent._iteration_preprocess() - raise NotImplemented return evaluate_metric - def _evaluate_training_step(self, model, train_loader, agent, epoch): - """ The training step for evaluating process (training from scratch). 
+ def _training_step(self, model, train_loader, agent, epoch, print_freq=100): + """ Args: model (nn.Module) train_loader (torch.utils.data.DataLoader) @@ -105,12 +42,16 @@ def _evaluate_training_step(self, model, train_loader, agent, epoch): epoch (int) """ model.train() + start_time = time.time() + for step, (X, y) in enumerate(train_loader): - raise NotImplemented + if agent.agent_state == "search": + agent._iteration_preprocess() - def _evaluate_validate_step(self, model, val_loader, agent, epoch): - """ The training step for evaluating process (training from scratch). + raise NotImplemented + def _validate_step(self, model, val_loader, agent, epoch): + """ Args: model (nn.Module) val_loader (torch.utils.data.DataLoader) @@ -121,29 +62,7 @@ def _evaluate_validate_step(self, model, val_loader, agent, epoch): evaluate_metric (float): The performance of the searched model. """ model.eval() - with torch.no_grad(): - for step, (X, y) in enumerate(val_loader): - raise NotImplemented - return evaluate_metric - - @staticmethod - def searching_evaluate(model, val_loader, device, criterion): - """ Evaluating the performance of the supernet. The search strategy will evaluate - the architectures by this static method to search. - - Args: - model (nn.Module) - val_loader (torch.utils.data.DataLoader) - device (torch.device) - criterion (nn.Module) - - Return: - evaluate_metric (float): The performance of the supernet. - """ - model.eval() - with torch.no_grad(): - for step, (X, y) in enumerate(val_loader): - raise NotImplemented - return evaluate_metric - - + start_time = time.time() + if agent.agent_state == "search": + agent._iteration_preprocess() + raise NotImplemented